LUCENE-6631 - added document classification api and impls

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1709522 13f79535-47bb-0310-9956-ffa450edef68
Tommaso Teofili 2015-10-20 07:36:41 +00:00
parent 0fe5ab3b9b
commit 71cea88773
12 changed files with 1031 additions and 92 deletions

View File

@@ -2,7 +2,6 @@ package org.apache.lucene.classification;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -81,38 +80,17 @@ public class CachingNaiveBayesClassifier extends SimpleNaiveBayesClassifier {
protected List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException {
String[] tokenizedDoc = tokenizeDoc(inputDocument);
String[] tokenizedText = tokenize(inputDocument);
List<ClassificationResult<BytesRef>> dataList = calculateLogLikelihood(tokenizedDoc);
List<ClassificationResult<BytesRef>> assignedClasses = calculateLogLikelihood(tokenizedText);
// normalization
// The values transforms to a 0-1 range
ArrayList<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
if (!dataList.isEmpty()) {
Collections.sort(dataList);
// this is a negative number closest to 0 = a
double smax = dataList.get(0).getScore();
double sumLog = 0;
// log(sum(exp(x_n-a)))
for (ClassificationResult<BytesRef> cr : dataList) {
// getScore-smax <=0 (both negative, smax is the smallest abs()
sumLog += Math.exp(cr.getScore() - smax);
}
// loga=a+log(sum(exp(x_n-a))) = log(sum(exp(x_n)))
double loga = smax;
loga += Math.log(sumLog);
// 1/sum*x = exp(log(x))*1/sum = exp(log(x)-log(sum))
for (ClassificationResult<BytesRef> cr : dataList) {
returnList.add(new ClassificationResult<>(cr.getAssignedClass(), Math.exp(cr.getScore() - loga)));
}
}
return returnList;
ArrayList<ClassificationResult<BytesRef>> assignedClassesNorm = super.normClassificationResults(assignedClasses);
return assignedClassesNorm;
}
private List<ClassificationResult<BytesRef>> calculateLogLikelihood(String[] tokenizedDoc) throws IOException {
private List<ClassificationResult<BytesRef>> calculateLogLikelihood(String[] tokenizedText) throws IOException {
// initialize the return List
ArrayList<ClassificationResult<BytesRef>> ret = new ArrayList<>();
for (BytesRef cclass : cclasses) {
@@ -120,7 +98,7 @@ public class CachingNaiveBayesClassifier extends SimpleNaiveBayesClassifier {
ret.add(cr);
}
// for each word
for (String word : tokenizedDoc) {
for (String word : tokenizedText) {
// search with text:word for all class:c
Map<BytesRef, Integer> hitsInClasses = getWordFreqForClassess(word);
// for each class

View File

@@ -48,12 +48,12 @@ import org.apache.lucene.util.BytesRef;
*/
public class KNearestNeighborClassifier implements Classifier<BytesRef> {
private final MoreLikeThis mlt;
private final String[] textFieldNames;
private final String classFieldName;
private final IndexSearcher indexSearcher;
private final int k;
private final Query query;
protected final MoreLikeThis mlt;
protected final String[] textFieldNames;
protected final String classFieldName;
protected final IndexSearcher indexSearcher;
protected final int k;
protected final Query query;
/**
* Creates a {@link KNearestNeighborClassifier}.
@@ -159,7 +159,7 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
}
// ranking of classes must be taken into consideration
private List<ClassificationResult<BytesRef>> buildListFromTopDocs(TopDocs topDocs) throws IOException {
protected List<ClassificationResult<BytesRef>> buildListFromTopDocs(TopDocs topDocs) throws IOException {
Map<BytesRef, Integer> classCounts = new HashMap<>();
Map<BytesRef, Double> classBoosts = new HashMap<>(); // this is a boost based on class ranking positions in topDocs
float maxScore = topDocs.getMaxScore();

View File

@@ -85,8 +85,9 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
* @param analyzer an {@link Analyzer} used to analyze unseen text
* @param query a {@link Query} to eventually filter the docs used for training the classifier, or {@code null}
* if all the indexed docs should be used
* @param classFieldName the name of the field used as the output for the classifier
* @param textFieldNames the name of the fields used as the inputs for the classifier
* @param classFieldName the name of the field used as the output for the classifier. NOTE: it must not be heavily analyzed,
* as the returned class will be a token indexed for this field
* @param textFieldNames the names of the fields used as inputs for the classifier; no per-field boosting is supported
*/
public SimpleNaiveBayesClassifier(LeafReader leafReader, Analyzer analyzer, Query query, String classFieldName, String... textFieldNames) {
this.leafReader = leafReader;
@@ -102,16 +103,16 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
*/
@Override
public ClassificationResult<BytesRef> assignClass(String inputDocument) throws IOException {
List<ClassificationResult<BytesRef>> doclist = assignClassNormalizedList(inputDocument);
ClassificationResult<BytesRef> retval = null;
List<ClassificationResult<BytesRef>> assignedClasses = assignClassNormalizedList(inputDocument);
ClassificationResult<BytesRef> assignedClass = null;
double maxscore = -Double.MAX_VALUE;
for (ClassificationResult<BytesRef> element : doclist) {
if (element.getScore() > maxscore) {
retval = element;
maxscore = element.getScore();
for (ClassificationResult<BytesRef> c : assignedClasses) {
if (c.getScore() > maxscore) {
assignedClass = c;
maxscore = c.getScore();
}
}
return retval;
return assignedClass;
}
/**
@@ -119,9 +120,9 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
*/
@Override
public List<ClassificationResult<BytesRef>> getClasses(String text) throws IOException {
List<ClassificationResult<BytesRef>> doclist = assignClassNormalizedList(text);
Collections.sort(doclist);
return doclist;
List<ClassificationResult<BytesRef>> assignedClasses = assignClassNormalizedList(text);
Collections.sort(assignedClasses);
return assignedClasses;
}
/**
@@ -129,9 +130,9 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
*/
@Override
public List<ClassificationResult<BytesRef>> getClasses(String text, int max) throws IOException {
List<ClassificationResult<BytesRef>> doclist = assignClassNormalizedList(text);
Collections.sort(doclist);
return doclist.subList(0, max);
List<ClassificationResult<BytesRef>> assignedClasses = assignClassNormalizedList(text);
Collections.sort(assignedClasses);
return assignedClasses.subList(0, max);
}
/**
@@ -141,46 +142,26 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
* @throws IOException if assigning probabilities fails
*/
protected List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException {
List<ClassificationResult<BytesRef>> dataList = new ArrayList<>();
List<ClassificationResult<BytesRef>> assignedClasses = new ArrayList<>();
Terms terms = MultiFields.getTerms(leafReader, classFieldName);
TermsEnum termsEnum = terms.iterator();
Terms classes = MultiFields.getTerms(leafReader, classFieldName);
TermsEnum classesEnum = classes.iterator();
BytesRef next;
String[] tokenizedDoc = tokenizeDoc(inputDocument);
String[] tokenizedText = tokenize(inputDocument);
int docsWithClassSize = countDocsWithClass();
while ((next = termsEnum.next()) != null) {
while ((next = classesEnum.next()) != null) {
if (next.length > 0) {
// We are passing the term to IndexSearcher so we need to make sure it will not change over time
next = BytesRef.deepCopyOf(next);
double clVal = calculateLogPrior(next, docsWithClassSize) + calculateLogLikelihood(tokenizedDoc, next, docsWithClassSize);
dataList.add(new ClassificationResult<>(next, clVal));
double clVal = calculateLogPrior(next, docsWithClassSize) + calculateLogLikelihood(tokenizedText, next, docsWithClassSize);
assignedClasses.add(new ClassificationResult<>(next, clVal));
}
}
// normalization; the values transforms to a 0-1 range
ArrayList<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
if (!dataList.isEmpty()) {
Collections.sort(dataList);
// this is a negative number closest to 0 = a
double smax = dataList.get(0).getScore();
ArrayList<ClassificationResult<BytesRef>> assignedClassesNorm = normClassificationResults(assignedClasses);
double sumLog = 0;
// log(sum(exp(x_n-a)))
for (ClassificationResult<BytesRef> cr : dataList) {
// getScore-smax <=0 (both negative, smax is the smallest abs()
sumLog += Math.exp(cr.getScore() - smax);
}
// loga=a+log(sum(exp(x_n-a))) = log(sum(exp(x_n)))
double loga = smax;
loga += Math.log(sumLog);
// 1/sum*x = exp(log(x))*1/sum = exp(log(x)-log(sum))
for (ClassificationResult<BytesRef> cr : dataList) {
returnList.add(new ClassificationResult<>(cr.getAssignedClass(), Math.exp(cr.getScore() - loga)));
}
}
return returnList;
return assignedClassesNorm;
}
/**
@@ -192,15 +173,15 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
protected int countDocsWithClass() throws IOException {
int docCount = MultiFields.getTerms(this.leafReader, this.classFieldName).getDocCount();
if (docCount == -1) { // in case codec doesn't support getDocCount
TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
TotalHitCountCollector classQueryCountCollector = new TotalHitCountCollector();
BooleanQuery.Builder q = new BooleanQuery.Builder();
q.add(new BooleanClause(new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))), BooleanClause.Occur.MUST));
if (query != null) {
q.add(query, BooleanClause.Occur.MUST);
}
indexSearcher.search(q.build(),
totalHitCountCollector);
docCount = totalHitCountCollector.getTotalHits();
classQueryCountCollector);
docCount = classQueryCountCollector.getTotalHits();
}
return docCount;
}
@@ -208,14 +189,14 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
/**
* tokenize a <code>String</code> on this classifier's text fields and analyzer
*
* @param doc the <code>String</code> representing an input text (to be classified)
* @param text the <code>String</code> representing an input text (to be classified)
* @return a <code>String</code> array of the resulting tokens
* @throws IOException if tokenization fails
*/
protected String[] tokenizeDoc(String doc) throws IOException {
protected String[] tokenize(String text) throws IOException {
Collection<String> result = new LinkedList<>();
for (String textFieldName : textFieldNames) {
try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, doc)) {
try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, text)) {
CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class);
tokenStream.reset();
while (tokenStream.incrementToken()) {
@@ -227,18 +208,18 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
return result.toArray(new String[result.size()]);
}
private double calculateLogLikelihood(String[] tokenizedDoc, BytesRef c, int docsWithClassSize) throws IOException {
private double calculateLogLikelihood(String[] tokenizedText, BytesRef c, int docsWithClass) throws IOException {
// for each word
double result = 0d;
for (String word : tokenizedDoc) {
for (String word : tokenizedText) {
// search with text:word AND class:c
int hits = getWordFreqForClass(word, c);
int hits = getWordFreqForClass(word,c);
// num : count the no of times the word appears in documents of class c (+1)
double num = hits + 1; // +1 is added because of add 1 smoothing
// den : for the whole dictionary, count the no of times a word appears in documents of class c (+|V|)
double den = getTextTermFreqForClass(c) + docsWithClassSize;
double den = getTextTermFreqForClass(c) + docsWithClass;
// P(w|c) = num/den
double wordProbability = num / den;
@@ -249,6 +230,12 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
return result;
}
/**
* Returns the average number of unique terms times the number of docs belonging to the input class
* @param c the class
* @return the average number of unique terms times the number of docs belonging to the input class
* @throws IOException if a low level I/O problem happens
*/
private double getTextTermFreqForClass(BytesRef c) throws IOException {
double avgNumberOfUniqueTerms = 0;
for (String textFieldName : textFieldNames) {
@@ -260,6 +247,14 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c
}
/**
* Returns the number of documents of the input class (from the whole index or from a subset)
* that contain the word (in a specific field, or in all the fields if none is selected)
* @param word the token produced by the analyzer
* @param c the class
* @return the number of documents of the input class
* @throws IOException if a low level I/O problem happens
*/
private int getWordFreqForClass(String word, BytesRef c) throws IOException {
BooleanQuery.Builder booleanQuery = new BooleanQuery.Builder();
BooleanQuery.Builder subQuery = new BooleanQuery.Builder();
@@ -283,4 +278,36 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
private int docCount(BytesRef countedClass) throws IOException {
return leafReader.docFreq(new Term(classFieldName, countedClass));
}
/**
* Normalize the classification results based on the max score available
* @param assignedClasses the list of assigned classes
* @return the normalized results
*/
protected ArrayList<ClassificationResult<BytesRef>> normClassificationResults(List<ClassificationResult<BytesRef>> assignedClasses) {
// normalization; the values are transformed to a 0-1 range
ArrayList<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
if (!assignedClasses.isEmpty()) {
Collections.sort(assignedClasses);
// smax is the negative number closest to 0, playing the role of "a" below
double smax = assignedClasses.get(0).getScore();
double sumLog = 0;
// log(sum(exp(x_n-a)))
for (ClassificationResult<BytesRef> cr : assignedClasses) {
// getScore() - smax <= 0 (both are negative; smax has the smallest absolute value)
sumLog += Math.exp(cr.getScore() - smax);
}
// loga=a+log(sum(exp(x_n-a))) = log(sum(exp(x_n)))
double loga = smax;
loga += Math.log(sumLog);
// 1/sum*x = exp(log(x))*1/sum = exp(log(x)-log(sum))
for (ClassificationResult<BytesRef> cr : assignedClasses) {
double scoreDiff = cr.getScore() - loga;
returnList.add(new ClassificationResult<>(cr.getAssignedClass(), Math.exp(scoreDiff)));
}
}
return returnList;
}
}
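The normalization performed by normClassificationResults is the standard log-sum-exp trick written out; with the per-class log scores x_n (all negative) and a = max_n x_n, the inline comments above correspond to:

    log(sum_n exp(x_n)) = a + log(sum_n exp(x_n - a))
    p_n = exp(x_n) / sum_m exp(x_m) = exp(x_n - log(sum_m exp(x_m)))

so each returned score p_n lies in the 0-1 range and the scores sum to 1; subtracting a before exponentiating keeps the exponentials from underflowing.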

View File

@@ -0,0 +1,61 @@
package org.apache.lucene.classification.document;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.io.IOException;
import java.util.List;
import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.document.Document;
/**
* A classifier, see <code>http://en.wikipedia.org/wiki/Classifier_(mathematics)</code>, which assigns classes of type
* <code>T</code> to {@link org.apache.lucene.document.Document}s
*
* @lucene.experimental
*/
public interface DocumentClassifier<T> {
/**
* Assign a class (with score) to the given {@link org.apache.lucene.document.Document}
*
* @param document a {@link org.apache.lucene.document.Document} to be classified. Fields are considered features for the classification.
* @return a {@link org.apache.lucene.classification.ClassificationResult} holding assigned class of type <code>T</code> and score
* @throws java.io.IOException If there is a low-level I/O error.
*/
ClassificationResult<T> assignClass(Document document) throws IOException;
/**
* Get all the classes (sorted by score, descending) assigned to the given {@link org.apache.lucene.document.Document}.
*
* @param document a {@link org.apache.lucene.document.Document} to be classified. Fields are considered features for the classification.
* @return the whole list of {@link org.apache.lucene.classification.ClassificationResult}, the classes and scores. Returns <code>null</code> if the classifier can't make lists.
* @throws java.io.IOException If there is a low-level I/O error.
*/
List<ClassificationResult<T>> getClasses(Document document) throws IOException;
/**
* Get the first <code>max</code> classes (sorted by score, descending) assigned to the given {@link org.apache.lucene.document.Document}.
*
* @param document a {@link org.apache.lucene.document.Document} to be classified. Fields are considered features for the classification.
* @param max the number of return list elements
* @return the list of {@link org.apache.lucene.classification.ClassificationResult}, the classes and scores, truncated to the first "max" elements. Returns <code>null</code> if the classifier can't make lists.
* @throws java.io.IOException If there is a low-level I/O error.
*/
List<ClassificationResult<T>> getClasses(Document document, int max) throws IOException;
}
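A minimal usage sketch of the new interface, composing it with one of the implementations added in this commit (the directory, the "text" and "category" field names, and the analyzer choice below are assumptions for illustration, not part of the commit):

import java.util.Collections;
import java.util.Map;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.SlowCompositeReaderWrapper;
import org.apache.lucene.util.BytesRef;

// open the training index; "directory" is an already-populated Directory (assumed)
LeafReader reader = SlowCompositeReaderWrapper.wrap(DirectoryReader.open(directory));
// one analyzer per input field; here a single "text" field analyzed with StandardAnalyzer
Map<String, Analyzer> field2analyzer = Collections.singletonMap("text", new StandardAnalyzer());
// train on all docs (null query), output the "category" field, use "text" as input
DocumentClassifier<BytesRef> classifier =
    new SimpleNaiveBayesDocumentClassifier(reader, null, "category", field2analyzer, "text");
// classify an unseen document built on the fly
Document unseen = new Document();
unseen.add(new TextField("text", "an unseen piece of content to classify", Field.Store.NO));
ClassificationResult<BytesRef> result = classifier.assignClass(unseen);
System.out.println(result.getAssignedClass().utf8ToString() + " -> " + result.getScore());
reader.close();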

View File

@@ -0,0 +1,146 @@
package org.apache.lucene.classification.document;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.io.IOException;
import java.io.StringReader;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.KNearestNeighborClassifier;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.WildcardQuery;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.util.BytesRef;
/**
* A k-Nearest Neighbor Document classifier (see <code>http://en.wikipedia.org/wiki/K-nearest_neighbors</code>) based
* on {@link org.apache.lucene.queries.mlt.MoreLikeThis} .
*
* @lucene.experimental
*/
public class KNearestNeighborDocumentClassifier extends KNearestNeighborClassifier implements DocumentClassifier<BytesRef> {
protected Map<String, Analyzer> field2analyzer;
/**
* Creates a {@link KNearestNeighborDocumentClassifier}.
*
* @param leafReader the reader on the index to be used for classification
* @param similarity the {@link Similarity} to be used by the underlying {@link IndexSearcher} or {@code null}
* (defaults to {@link org.apache.lucene.search.similarities.ClassicSimilarity})
* @param query a {@link org.apache.lucene.search.Query} to eventually filter the docs used for training the classifier, or {@code null}
* if all the indexed docs should be used
* @param k the no. of docs to select in the MLT results to find the nearest neighbor
* @param minDocsFreq {@link org.apache.lucene.queries.mlt.MoreLikeThis#minDocFreq} parameter
* @param minTermFreq {@link org.apache.lucene.queries.mlt.MoreLikeThis#minTermFreq} parameter
* @param classFieldName the name of the field used as the output for the classifier
* @param field2analyzer a map with a field name as key and the related {@link org.apache.lucene.analysis.Analyzer} as value
* @param textFieldNames the names of the fields used as inputs for the classifier; they can contain a boost indication, e.g. title^10
*/
public KNearestNeighborDocumentClassifier(LeafReader leafReader, Similarity similarity, Query query, int k, int minDocsFreq,
int minTermFreq, String classFieldName, Map<String, Analyzer> field2analyzer, String... textFieldNames) {
super(leafReader, similarity, null, query, k, minDocsFreq, minTermFreq, classFieldName, textFieldNames);
this.field2analyzer = field2analyzer;
}
/**
* {@inheritDoc}
*/
@Override
public ClassificationResult<BytesRef> assignClass(Document document) throws IOException {
TopDocs knnResults = knnSearch(document);
List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults);
ClassificationResult<BytesRef> assignedClass = null;
double maxscore = -Double.MAX_VALUE;
for (ClassificationResult<BytesRef> cl : assignedClasses) {
if (cl.getScore() > maxscore) {
assignedClass = cl;
maxscore = cl.getScore();
}
}
return assignedClass;
}
/**
* {@inheritDoc}
*/
@Override
public List<ClassificationResult<BytesRef>> getClasses(Document document) throws IOException {
TopDocs knnResults = knnSearch(document);
List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults);
Collections.sort(assignedClasses);
return assignedClasses;
}
/**
* {@inheritDoc}
*/
@Override
public List<ClassificationResult<BytesRef>> getClasses(Document document, int max) throws IOException {
TopDocs knnResults = knnSearch(document);
List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults);
Collections.sort(assignedClasses);
return assignedClasses.subList(0, max);
}
/**
* Returns the top k results from a More Like This query based on the input document
*
* @param document the document to use for More Like This search
* @return the top results for the MLT query
* @throws IOException If there is a low-level I/O error
*/
private TopDocs knnSearch(Document document) throws IOException {
BooleanQuery.Builder mltQuery = new BooleanQuery.Builder();
for (String fieldName : textFieldNames) {
String boost = null;
if (fieldName.contains("^")) {
String[] field2boost = fieldName.split("\\^");
fieldName = field2boost[0];
boost = field2boost[1];
}
String[] fieldValues = document.getValues(fieldName);
if (boost != null) {
mlt.setBoost(true);
mlt.setBoostFactor(Float.parseFloat(boost));
}
mlt.setAnalyzer(field2analyzer.get(fieldName));
for (String fieldContent : fieldValues) {
mltQuery.add(new BooleanClause(mlt.like(fieldName, new StringReader(fieldContent)), BooleanClause.Occur.SHOULD));
}
mlt.setBoost(false);
}
Query classFieldQuery = new WildcardQuery(new Term(classFieldName, "*"));
mltQuery.add(new BooleanClause(classFieldQuery, BooleanClause.Occur.MUST));
if (query != null) {
mltQuery.add(query, BooleanClause.Occur.MUST);
}
return indexSearcher.search(mltQuery.build(), k);
}
}
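Since knnSearch parses an optional "^boost" suffix on each text field name, a boosted configuration can be set up as in this hypothetical sketch (reader, field2analyzer and unseenDocument are assumed to exist, e.g. as in the earlier sketch; imports as before, plus java.util.List):

// boost the title field 10x in the MoreLikeThis query, using the
// "title^10" convention parsed by knnSearch above
DocumentClassifier<BytesRef> knn = new KNearestNeighborDocumentClassifier(
    reader,   // LeafReader over the training index
    null,     // null similarity defaults to ClassicSimilarity
    null,     // no filter query: train on all indexed docs
    3,        // k: number of nearest neighbors to retrieve
    1, 1,     // MoreLikeThis minDocsFreq / minTermFreq
    "category", field2analyzer, "text", "title^10", "author");
List<ClassificationResult<BytesRef>> ranked = knn.getClasses(unseenDocument, 2);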

View File

@@ -0,0 +1,289 @@
package org.apache.lucene.classification.document;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.SimpleNaiveBayesClassifier;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.MultiFields;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.search.WildcardQuery;
import org.apache.lucene.util.BytesRef;
/**
* A simplistic Lucene based NaiveBayes document classifier, see {@code http://en.wikipedia.org/wiki/Naive_Bayes_classifier}
*
* @lucene.experimental
*/
public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifier implements DocumentClassifier<BytesRef> {
/**
* {@link org.apache.lucene.analysis.Analyzer} to be used for tokenizing document fields
*/
protected Map<String, Analyzer> field2analyzer;
/**
* Creates a new NaiveBayes classifier.
*
* @param leafReader the reader on the index to be used for classification
* @param query a {@link org.apache.lucene.search.Query} to eventually filter the docs used for training the classifier, or {@code null}
* if all the indexed docs should be used
* @param classFieldName the name of the field used as the output for the classifier. NOTE: it must not be heavily analyzed,
* as the returned class will be a token indexed for this field
* @param field2analyzer a map with a field name as key and the related {@link org.apache.lucene.analysis.Analyzer} as value
* @param textFieldNames the names of the fields used as inputs for the classifier; they can contain a boost indication, e.g. title^10
*/
public SimpleNaiveBayesDocumentClassifier(LeafReader leafReader, Query query, String classFieldName, Map<String, Analyzer> field2analyzer, String... textFieldNames) {
super(leafReader, null, query, classFieldName, textFieldNames);
this.field2analyzer = field2analyzer;
}
/**
* {@inheritDoc}
*/
@Override
public ClassificationResult<BytesRef> assignClass(Document document) throws IOException {
List<ClassificationResult<BytesRef>> assignedClasses = assignNormClasses(document);
ClassificationResult<BytesRef> assignedClass = null;
double maxscore = -Double.MAX_VALUE;
for (ClassificationResult<BytesRef> c : assignedClasses) {
if (c.getScore() > maxscore) {
assignedClass = c;
maxscore = c.getScore();
}
}
return assignedClass;
}
/**
* {@inheritDoc}
*/
@Override
public List<ClassificationResult<BytesRef>> getClasses(Document document) throws IOException {
List<ClassificationResult<BytesRef>> assignedClasses = assignNormClasses(document);
Collections.sort(assignedClasses);
return assignedClasses;
}
/**
* {@inheritDoc}
*/
@Override
public List<ClassificationResult<BytesRef>> getClasses(Document document, int max) throws IOException {
List<ClassificationResult<BytesRef>> assignedClasses = assignNormClasses(document);
Collections.sort(assignedClasses);
return assignedClasses.subList(0, max);
}
private List<ClassificationResult<BytesRef>> assignNormClasses(Document inputDocument) throws IOException {
List<ClassificationResult<BytesRef>> assignedClasses = new ArrayList<>();
Map<String, List<String[]>> fieldName2tokensArray = new LinkedHashMap<>();
Map<String, Float> fieldName2boost = new LinkedHashMap<>();
Terms classes = MultiFields.getTerms(leafReader, classFieldName);
TermsEnum classesEnum = classes.iterator();
BytesRef c;
analyzeSeedDocument(inputDocument, fieldName2tokensArray, fieldName2boost);
int docsWithClassSize = countDocsWithClass();
while ((c = classesEnum.next()) != null) {
double classScore = 0;
for (String fieldName : textFieldNames) {
List<String[]> tokensArrays = fieldName2tokensArray.get(fieldName);
double fieldScore = 0;
for (String[] fieldTokensArray : tokensArrays) {
fieldScore += calculateLogPrior(c, docsWithClassSize) + calculateLogLikelihood(fieldTokensArray, fieldName, c, docsWithClassSize) * fieldName2boost.get(fieldName);
}
classScore += fieldScore;
}
assignedClasses.add(new ClassificationResult<>(BytesRef.deepCopyOf(c), classScore));
}
ArrayList<ClassificationResult<BytesRef>> assignedClassesNorm = normClassificationResults(assignedClasses);
return assignedClassesNorm;
}
/**
* This method performs the analysis of the seed document and extracts the boosts if present.
* This is done only once for the seed document.
*
* @param inputDocument the seed unseen document
* @param fieldName2tokensArray a map that associates to each field name the list of token arrays for all its values
* @param fieldName2boost a map that associates each field name to its boost
* @throws IOException If there is a low-level I/O error
*/
private void analyzeSeedDocument(Document inputDocument, Map<String, List<String[]>> fieldName2tokensArray, Map<String, Float> fieldName2boost) throws IOException {
for (int i = 0; i < textFieldNames.length; i++) {
String fieldName = textFieldNames[i];
float boost = 1;
List<String[]> tokenizedValues = new LinkedList<>();
if (fieldName.contains("^")) {
String[] field2boost = fieldName.split("\\^");
fieldName = field2boost[0];
boost = Float.parseFloat(field2boost[1]);
}
Field[] fieldValues = inputDocument.getFields(fieldName);
for (Field fieldValue : fieldValues) {
TokenStream fieldTokens = fieldValue.tokenStream(field2analyzer.get(fieldName), null);
String[] fieldTokensArray = getTokenArray(fieldTokens);
tokenizedValues.add(fieldTokensArray);
}
fieldName2tokensArray.put(fieldName, tokenizedValues);
fieldName2boost.put(fieldName, boost);
textFieldNames[i] = fieldName;
}
}
/**
* Counts the number of documents in the index having at least one value for the 'class' field
*
* @return the no. of documents having a value for the 'class' field
* @throws java.io.IOException If accessing term vectors or searching fails
*/
protected int countDocsWithClass() throws IOException {
int docCount = MultiFields.getTerms(this.leafReader, this.classFieldName).getDocCount();
if (docCount == -1) { // in case codec doesn't support getDocCount
TotalHitCountCollector classQueryCountCollector = new TotalHitCountCollector();
BooleanQuery.Builder q = new BooleanQuery.Builder();
q.add(new BooleanClause(new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))), BooleanClause.Occur.MUST));
if (query != null) {
q.add(query, BooleanClause.Occur.MUST);
}
indexSearcher.search(q.build(),
classQueryCountCollector);
docCount = classQueryCountCollector.getTotalHits();
}
return docCount;
}
/**
* Returns a token array from the {@link org.apache.lucene.analysis.TokenStream} in input
*
* @param tokenizedText the tokenized content of a field
* @return a {@code String} array of the resulting tokens
* @throws java.io.IOException If tokenization fails because there is a low-level I/O error
*/
protected String[] getTokenArray(TokenStream tokenizedText) throws IOException {
Collection<String> tokens = new LinkedList<>();
CharTermAttribute charTermAttribute = tokenizedText.addAttribute(CharTermAttribute.class);
tokenizedText.reset();
while (tokenizedText.incrementToken()) {
tokens.add(charTermAttribute.toString());
}
tokenizedText.end();
tokenizedText.close();
return tokens.toArray(new String[tokens.size()]);
}
/**
* @param tokenizedText the tokenized content of a field
* @param fieldName the input field name
* @param c the class to calculate the score of
* @param docsWithClass the total number of docs that have a class
* @return a normalized score for the class
* @throws IOException If there is a low-level I/O error
*/
private double calculateLogLikelihood(String[] tokenizedText, String fieldName, BytesRef c, int docsWithClass) throws IOException {
// for each word
double result = 0d;
for (String word : tokenizedText) {
// search with text:word AND class:c
int hits = getWordFreqForClass(word, fieldName, c);
// num : count the no of times the word appears in documents of class c (+1)
double num = hits + 1; // +1 is added because of add 1 smoothing
// den : for the whole dictionary, count the no of times a word appears in documents of class c (+|V|)
double den = getTextTermFreqForClass(c, fieldName) + docsWithClass;
// P(w|c) = num/den
double wordProbability = num / den;
result += Math.log(wordProbability);
}
// log(P(d|c)) = log(P(w1|c))+...+log(P(wn|c))
double normScore = result / (tokenizedText.length); // normalized because otherwise long text fields would always be more important than short fields
return normScore;
}
/**
* Returns the average number of unique terms times the number of docs belonging to the input class
*
* @param c the class
* @param fieldName the name of the text field the statistic is computed on
* @return the average number of unique terms times the number of docs belonging to the input class
* @throws java.io.IOException If there is a low-level I/O error
*/
private double getTextTermFreqForClass(BytesRef c, String fieldName) throws IOException {
double avgNumberOfUniqueTerms;
Terms terms = MultiFields.getTerms(leafReader, fieldName);
long numPostings = terms.getSumDocFreq(); // number of term/doc pairs
avgNumberOfUniqueTerms = numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc
int docsWithC = leafReader.docFreq(new Term(classFieldName, c));
return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c
}
/**
* Returns the number of documents of the input class (from the whole index or from a subset)
* that contain the word (in a specific field, or in all the fields if none is selected)
*
* @param word the token produced by the analyzer
* @param fieldName the field the word is coming from
* @param c the class
* @return number of documents of the input class
* @throws java.io.IOException If there is a low-level I/O error
*/
private int getWordFreqForClass(String word, String fieldName, BytesRef c) throws IOException {
BooleanQuery.Builder booleanQuery = new BooleanQuery.Builder();
BooleanQuery.Builder subQuery = new BooleanQuery.Builder();
subQuery.add(new BooleanClause(new TermQuery(new Term(fieldName, word)), BooleanClause.Occur.SHOULD));
booleanQuery.add(new BooleanClause(subQuery.build(), BooleanClause.Occur.MUST));
booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST));
if (query != null) {
booleanQuery.add(query, BooleanClause.Occur.MUST);
}
TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
indexSearcher.search(booleanQuery.build(), totalHitCountCollector);
return totalHitCountCollector.getTotalHits();
}
private double calculateLogPrior(BytesRef currentClass, int docsWithClassSize) throws IOException {
return Math.log((double) docCount(currentClass)) - Math.log(docsWithClassSize);
}
private int docCount(BytesRef countedClass) throws IOException {
return leafReader.docFreq(new Term(classFieldName, countedClass));
}
}
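In symbols, assignNormClasses computes for each class c (before the final normalization):

    score(c) = sum over input fields f, sum over values v of f in the document, of
               log P(c) + boost(f) * (1/|v|) * sum over words w in v of log P(w|c,f)

where calculateLogLikelihood and getWordFreqForClass above give

    P(w|c,f) = (docFreq(f:w AND class:c) + 1) / (avgUniqueTerms(f) * docCount(class:c) + docsWithClass)

i.e. add-one smoothing in the numerator, and in the denominator the average number of unique terms per document in field f times the number of docs of class c, plus the total number of classified docs. Note that, exactly as the loop is written, the prior term log P(c) is added once per field value rather than once per class.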

View File

@@ -0,0 +1,7 @@
/**
* Uses already seen data (the indexed documents) to classify new documents.
* <p>
* Currently contains a (simplistic) Naive Bayes classifier and a k-Nearest
* Neighbor classifier.
*/
package org.apache.lucene.classification.document;

View File

@@ -16,7 +16,7 @@
*/
/**
* Uses already seen data (the indexed documents) to classify new documents.
* Uses already seen data (the indexed documents) to classify an input (which can be simple text or a structured document).
* <p>
* Currently contains a (simplistic) Naive Bayes classifier, a k-Nearest
* Neighbor classifier and a Perceptron based classifier.

View File

@@ -57,8 +57,8 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
protected static final BytesRef TECHNOLOGY_RESULT = new BytesRef("technology");
protected RandomIndexWriter indexWriter;
private Directory dir;
private FieldType ft;
protected Directory dir;
protected FieldType ft;
protected String textFieldName;
protected String categoryFieldName;

View File

@@ -0,0 +1,259 @@
package org.apache.lucene.classification.document;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.Map;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.en.EnglishAnalyzer;
import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.ClassificationTestBase;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.SlowCompositeReaderWrapper;
import org.apache.lucene.util.BytesRef;
import org.junit.Before;
/**
* Base class for testing {@link org.apache.lucene.classification.document.DocumentClassifier}s
*/
public abstract class DocumentClassificationTestBase<T> extends ClassificationTestBase {
protected static final BytesRef VIDEOGAME_RESULT = new BytesRef("videogames");
protected static final BytesRef VIDEOGAME_ANALYZED_RESULT = new BytesRef("videogam");
protected static final BytesRef BATMAN_RESULT = new BytesRef("batman");
protected String titleFieldName = "title";
protected String authorFieldName = "author";
protected Analyzer analyzer;
protected Map<String, Analyzer> field2analyzer;
protected LeafReader leafReader;
@Before
public void init() throws IOException {
analyzer = new EnglishAnalyzer();
field2analyzer = new LinkedHashMap<>();
field2analyzer.put(textFieldName, analyzer);
field2analyzer.put(titleFieldName, analyzer);
field2analyzer.put(authorFieldName, analyzer);
leafReader = populateDocumentClassificationIndex(analyzer);
}
protected double checkCorrectDocumentClassification(DocumentClassifier<T> classifier, Document inputDoc, T expectedResult) throws Exception {
ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
assertNotNull(classificationResult.getAssignedClass());
assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
double score = classificationResult.getScore();
assertTrue("score should be between 0 and 1, got:" + score, score <= 1 && score >= 0);
return score;
}
protected LeafReader populateDocumentClassificationIndex(Analyzer analyzer) throws IOException {
indexWriter.close();
indexWriter = new RandomIndexWriter(random(), dir, newIndexWriterConfig(analyzer).setOpenMode(IndexWriterConfig.OpenMode.CREATE));
indexWriter.commit();
String text;
String title;
String author;
Document doc = new Document();
title = "Video games are an economic business";
text = "Video games have become an art form and an industry. The video game industry is of increasing" +
" commercial importance, with growth driven particularly by the emerging Asian markets and mobile games." +
" As of 2015, video games generated sales of USD 74 billion annually worldwide, and were the third-largest" +
" segment in the U.S. entertainment market, behind broadcast and cable TV.";
author = "Ign";
doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(titleFieldName, title, ft));
doc.add(new Field(authorFieldName, author, ft));
doc.add(new Field(categoryFieldName, "videogames", ft));
doc.add(new Field(booleanFieldName, "false", ft));
indexWriter.addDocument(doc);
doc = new Document();
title = "Video games: the definition of fun on PC and consoles";
text = "A video game is an electronic game that involves human interaction with a user interface to generate" +
" visual feedback on a video device. The word video in video game traditionally referred to a raster display device," +
"[1] but it now implies any type of display device that can produce two- or three-dimensional images." +
" The electronic systems used to play video games are known as platforms; examples of these are personal" +
" computers and video game consoles. These platforms range from large mainframe computers to small handheld devices." +
" Specialized video games such as arcade games, while previously common, have gradually declined in use.";
author = "Ign";
doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(titleFieldName, title, ft));
doc.add(new Field(authorFieldName, author, ft));
doc.add(new Field(categoryFieldName, "videogames", ft));
doc.add(new Field(booleanFieldName, "false", ft));
indexWriter.addDocument(doc);
doc = new Document();
title = "Video games: the history across PC, consoles and fun";
text = "Early games used interactive electronic devices with various display formats. The earliest example is" +
" from 1947—a device was filed for a patent on 25 January 1947, by Thomas T. Goldsmith Jr. and Estle Ray Mann," +
" and issued on 14 December 1948, as U.S. Patent 2455992.[2]" +
"Inspired by radar display tech, it consisted of an analog device that allowed a user to control a vector-drawn" +
" dot on the screen to simulate a missile being fired at targets, which were drawings fixed to the screen.[3]";
author = "Ign";
doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(titleFieldName, title, ft));
doc.add(new Field(authorFieldName, author, ft));
doc.add(new Field(categoryFieldName, "videogames", ft));
doc.add(new Field(booleanFieldName, "false", ft));
indexWriter.addDocument(doc);
doc = new Document();
title = "Video games: the history";
text = "Early games used interactive electronic devices with various display formats. The earliest example is" +
" from 1947—a device was filed for a patent on 25 January 1947, by Thomas T. Goldsmith Jr. and Estle Ray Mann," +
" and issued on 14 December 1948, as U.S. Patent 2455992.[2]" +
"Inspired by radar display tech, it consisted of an analog device that allowed a user to control a vector-drawn" +
" dot on the screen to simulate a missile being fired at targets, which were drawings fixed to the screen.[3]";
author = "Ign";
doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(titleFieldName, title, ft));
doc.add(new Field(authorFieldName, author, ft));
doc.add(new Field(categoryFieldName, "videogames", ft));
doc.add(new Field(booleanFieldName, "false", ft));
indexWriter.addDocument(doc);
doc = new Document();
title = "Batman: Arkham Knight PC Benchmarks, For What They're Worth";
text = "Although I didnt spend much time playing Batman: Arkham Origins, I remember the game rather well after" +
" testing it on no less than 30 graphics cards and 20 CPUs. Arkham Origins appeared to take full advantage of" +
" Unreal Engine 3, it ran smoothly on affordable GPUs, though its worth remembering that Origins was developed " +
"for last-gen consoles.This week marked the arrival of Batman: Arkham Knight, the fourth entry in WBs Batman:" +
" Arkham series and a direct sequel to 2013s Arkham Origins 2011s Arkham City." +
"Arkham Knight is also powered by Unreal Engine 3, but you can expect noticeably improved graphics, in part because" +
" the PlayStation 4 and Xbox One have replaced the PS3 and 360 as the lowest common denominator.";
author = "Rocksteady Studios";
doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(titleFieldName, title, ft));
doc.add(new Field(authorFieldName, author, ft));
doc.add(new Field(categoryFieldName, "batman", ft));
doc.add(new Field(booleanFieldName, "false", ft));
indexWriter.addDocument(doc);
doc = new Document();
title = "Face-Off: Batman: Arkham Knight, the Dark Knight returns!";
text = "Despite the drama surrounding the PC release leading to its subsequent withdrawal, there's a sense of success" +
" in the console space as PlayStation 4 owners, and indeed those on Xbox One, get a superb rendition of Batman:" +
" Arkham Knight. It's fair to say Rocksteady sized up each console's strengths well ahead of producing its first" +
" current-gen title, and it's paid off in one of the best Batman games we've seen in years.";
author = "Rocksteady Studios";
doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(titleFieldName, title, ft));
doc.add(new Field(authorFieldName, author, ft));
doc.add(new Field(categoryFieldName, "batman", ft));
doc.add(new Field(booleanFieldName, "false", ft));
indexWriter.addDocument(doc);
doc = new Document();
title = "Batman: Arkham Knight Having More Trouble, But This Time not in Gotham";
text = "As news began to break about the numerous issues affecting the PC version of Batman: Arkham Knight, players" +
" of the console version breathed a sigh of relief and got back to playing the game. Now players of the PlayStation" +
" 4 version are having problems of their own, albeit much less severe ones." +
"This time Batman will have a difficult time in Gotham.";
author = "Rocksteady Studios";
doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(titleFieldName, title, ft));
doc.add(new Field(authorFieldName, author, ft));
doc.add(new Field(categoryFieldName, "batman", ft));
doc.add(new Field(booleanFieldName, "false", ft));
indexWriter.addDocument(doc);
doc = new Document();
title = "Batman: Arkham Knight the new legend of Gotham";
text = "As news began to break about the numerous issues affecting the PC version of the game, players" +
" of the console version breathed a sigh of relief and got back to play. Now players of the PlayStation" +
" 4 version are having problems of their own, albeit much less severe ones.";
author = "Rocksteady Studios";
doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(titleFieldName, title, ft));
doc.add(new Field(authorFieldName, author, ft));
doc.add(new Field(categoryFieldName, "batman", ft));
doc.add(new Field(booleanFieldName, "false", ft));
indexWriter.addDocument(doc);
doc = new Document();
text = "unlabeled doc";
doc.add(new Field(textFieldName, text, ft));
indexWriter.addDocument(doc);
indexWriter.commit();
return SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
}
protected Document getVideoGameDocument() {
Document doc = new Document();
String title = "The new generation of PC and Console Video games";
String text = "Recently a lot of games have been released for the latest generations of consoles and personal computers." +
"One of them is Batman: Arkham Knight released recently on PS4, X-box and personal computer." +
"Another important video game that will be released in November is Assassin's Creed, a classic series that sees its new installement on Halloween." +
"Recently a lot of problems affected the Assassin's creed series but this time it should ran smoothly on affordable GPUs." +
"Players are waiting for the versions of their favourite video games and so do we.";
String author = "Ign";
doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(titleFieldName, title, ft));
doc.add(new Field(authorFieldName, author, ft));
doc.add(new Field(booleanFieldName, "false", ft));
return doc;
}
protected Document getBatmanDocument() {
Document doc = new Document();
String title = "Batman: Arkham Knight new adventures for the super hero across Gotham, the Dark Knight has returned!";
String title2 = "I am a second title !";
String text = "This game is the electronic version of the famous super hero adventures.It involves the interaction with the open world" +
" of the city of Gotham. Finally the player will be able to have fun on its personal device." +
" The three-dimensional images of the game are stunning, because it uses the Unreal Engine 3." +
" The systems available are PS4, X-Box and personal computer." +
" Will the simulate missile that is going to be fired, success ?\" +\n" +
" Will this video game make the history" +
" Help you favourite super hero to defeat all his enemies. The Dark Knight has returned !";
String author = "Rocksteady Studios";
doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(titleFieldName, title, ft));
doc.add(new Field(titleFieldName, title2, ft));
doc.add(new Field(authorFieldName, author, ft));
doc.add(new Field(booleanFieldName, "false", ft));
return doc;
}
protected Document getBatmanAmbiguousDocument() {
Document doc = new Document();
String title = "Batman: Arkham Knight new adventures for the super hero across Gotham, the Dark Knight has returned! Batman will win!";
String text = "Early games used interactive electronic devices with various display formats. The earliest example is" +
" from 1947—a device was filed for a patent on 25 January 1947, by Thomas T. Goldsmith Jr. and Estle Ray Mann," +
" and issued on 14 December 1948, as U.S. Patent 2455992.[2]" +
"Inspired by radar display tech, it consisted of an analog device that allowed a user to control a vector-drawn" +
" dot on the screen to simulate a missile being fired at targets, which were drawings fixed to the screen.[3]";
String author = "Ign";
doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(titleFieldName, title, ft));
doc.add(new Field(authorFieldName, author, ft));
doc.add(new Field(booleanFieldName, "false", ft));
return doc;
}
}

View File

@@ -0,0 +1,96 @@
package org.apache.lucene.classification.document;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import org.apache.lucene.document.Document;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.util.BytesRef;
import org.junit.Test;
/**
* Tests for {@link org.apache.lucene.classification.document.KNearestNeighborDocumentClassifier}
*/
public class KNearestNeighborDocumentClassifierTest extends DocumentClassificationTestBase<BytesRef> {
@Test
public void testBasicDocumentClassification() throws Exception {
try {
Document videoGameDocument = getVideoGameDocument();
Document batmanDocument = getBatmanDocument();
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), videoGameDocument, VIDEOGAME_RESULT);
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), batmanDocument, BATMAN_RESULT);
// considering only the text we get a wrong classification because the text is ambiguous on purpose
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), videoGameDocument, BATMAN_RESULT);
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), batmanDocument, VIDEOGAME_RESULT);
} finally {
if (leafReader != null) {
leafReader.close();
}
}
}
@Test
public void testBasicDocumentClassificationScore() throws Exception {
try {
Document videoGameDocument = getVideoGameDocument();
Document batmanDocument = getBatmanDocument();
double score1 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), videoGameDocument, VIDEOGAME_RESULT);
assertEquals(1.0,score1,0);
double score2 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), batmanDocument, BATMAN_RESULT);
assertEquals(1.0,score2,0);
// considering only the text we get a wrong classification because the text is ambiguous on purpose
double score3 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), videoGameDocument, BATMAN_RESULT);
assertEquals(1.0,score3,0);
double score4 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), batmanDocument, VIDEOGAME_RESULT);
assertEquals(1.0,score4,0);
} finally {
if (leafReader != null) {
leafReader.close();
}
}
}
@Test
public void testBoostedDocumentClassification() throws Exception {
try {
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName + "^100", authorFieldName}), getBatmanAmbiguosDocument(), BATMAN_RESULT);
// considering without boost wrong classification will appear
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanAmbiguosDocument(), VIDEOGAME_RESULT);
} finally {
if (leafReader != null) {
leafReader.close();
}
}
}
@Test
public void testBasicDocumentClassificationWithQuery() throws Exception {
try {
TermQuery query = new TermQuery(new Term(authorFieldName, "ign"));
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, query, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_RESULT);
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null,query, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), VIDEOGAME_RESULT);
} finally {
if (leafReader != null) {
leafReader.close();
}
}
}
}

View File

@@ -0,0 +1,76 @@
package org.apache.lucene.classification.document;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import org.apache.lucene.util.BytesRef;
import org.junit.Test;
/**
* Tests for {@link org.apache.lucene.classification.document.SimpleNaiveBayesDocumentClassifier}
*/
public class SimpleNaiveBayesDocumentClassifierTest extends DocumentClassificationTestBase<BytesRef> {
@Test
public void testBasicDocumentClassification() throws Exception {
try {
checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_ANALYZED_RESULT);
checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), BATMAN_RESULT);
checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getVideoGameDocument(), BATMAN_RESULT);
checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getBatmanDocument(), VIDEOGAME_ANALYZED_RESULT);
} finally {
if (leafReader != null) {
leafReader.close();
}
}
}
@Test
public void testBasicDocumentClassificationScore() throws Exception {
try {
double score1 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_ANALYZED_RESULT);
assertEquals(0.88,score1,0.01);
double score2 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), BATMAN_RESULT);
assertEquals(0.89,score2,0.01);
// taking into consideration only the text
double score3 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getVideoGameDocument(), BATMAN_RESULT);
assertEquals(0.55,score3,0.01);
double score4 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getBatmanDocument(), VIDEOGAME_ANALYZED_RESULT);
assertEquals(0.52,score4,0.01);
} finally {
if (leafReader != null) {
leafReader.close();
}
}
}
@Test
public void testBoostedDocumentClassification() throws Exception {
try {
checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName+"^100", authorFieldName}), getBatmanAmbiguosDocument(), BATMAN_RESULT);
// considering without boost wrong classification will appear
checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanAmbiguosDocument(), VIDEOGAME_ANALYZED_RESULT);
} finally {
if (leafReader != null) {
leafReader.close();
}
}
}
}