mirror of https://github.com/apache/lucene.git
LUCENE-6631 - added document classification api and impls
git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1709522 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
0fe5ab3b9b
commit
71cea88773
|
@ -2,7 +2,6 @@ package org.apache.lucene.classification;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
@ -81,38 +80,17 @@ public class CachingNaiveBayesClassifier extends SimpleNaiveBayesClassifier {
|
|||
|
||||
|
||||
protected List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException {
|
||||
String[] tokenizedDoc = tokenizeDoc(inputDocument);
|
||||
String[] tokenizedText = tokenize(inputDocument);
|
||||
|
||||
List<ClassificationResult<BytesRef>> dataList = calculateLogLikelihood(tokenizedDoc);
|
||||
List<ClassificationResult<BytesRef>> assignedClasses = calculateLogLikelihood(tokenizedText);
|
||||
|
||||
// normalization
|
||||
// The values transforms to a 0-1 range
|
||||
ArrayList<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
|
||||
if (!dataList.isEmpty()) {
|
||||
Collections.sort(dataList);
|
||||
// this is a negative number closest to 0 = a
|
||||
double smax = dataList.get(0).getScore();
|
||||
|
||||
double sumLog = 0;
|
||||
// log(sum(exp(x_n-a)))
|
||||
for (ClassificationResult<BytesRef> cr : dataList) {
|
||||
// getScore-smax <=0 (both negative, smax is the smallest abs()
|
||||
sumLog += Math.exp(cr.getScore() - smax);
|
||||
}
|
||||
// loga=a+log(sum(exp(x_n-a))) = log(sum(exp(x_n)))
|
||||
double loga = smax;
|
||||
loga += Math.log(sumLog);
|
||||
|
||||
// 1/sum*x = exp(log(x))*1/sum = exp(log(x)-log(sum))
|
||||
for (ClassificationResult<BytesRef> cr : dataList) {
|
||||
returnList.add(new ClassificationResult<>(cr.getAssignedClass(), Math.exp(cr.getScore() - loga)));
|
||||
}
|
||||
}
|
||||
|
||||
return returnList;
|
||||
ArrayList<ClassificationResult<BytesRef>> asignedClassesNorm = super.normClassificationResults(assignedClasses);
|
||||
return asignedClassesNorm;
|
||||
}
|
||||
|
||||
private List<ClassificationResult<BytesRef>> calculateLogLikelihood(String[] tokenizedDoc) throws IOException {
|
||||
private List<ClassificationResult<BytesRef>> calculateLogLikelihood(String[] tokenizedText) throws IOException {
|
||||
// initialize the return List
|
||||
ArrayList<ClassificationResult<BytesRef>> ret = new ArrayList<>();
|
||||
for (BytesRef cclass : cclasses) {
|
||||
|
@ -120,7 +98,7 @@ public class CachingNaiveBayesClassifier extends SimpleNaiveBayesClassifier {
|
|||
ret.add(cr);
|
||||
}
|
||||
// for each word
|
||||
for (String word : tokenizedDoc) {
|
||||
for (String word : tokenizedText) {
|
||||
// search with text:word for all class:c
|
||||
Map<BytesRef, Integer> hitsInClasses = getWordFreqForClassess(word);
|
||||
// for each class
|
||||
|
|
|
@ -48,12 +48,12 @@ import org.apache.lucene.util.BytesRef;
|
|||
*/
|
||||
public class KNearestNeighborClassifier implements Classifier<BytesRef> {
|
||||
|
||||
private final MoreLikeThis mlt;
|
||||
private final String[] textFieldNames;
|
||||
private final String classFieldName;
|
||||
private final IndexSearcher indexSearcher;
|
||||
private final int k;
|
||||
private final Query query;
|
||||
protected final MoreLikeThis mlt;
|
||||
protected final String[] textFieldNames;
|
||||
protected final String classFieldName;
|
||||
protected final IndexSearcher indexSearcher;
|
||||
protected final int k;
|
||||
protected final Query query;
|
||||
|
||||
/**
|
||||
* Creates a {@link KNearestNeighborClassifier}.
|
||||
|
@ -159,7 +159,7 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
|
|||
}
|
||||
|
||||
//ranking of classes must be taken in consideration
|
||||
private List<ClassificationResult<BytesRef>> buildListFromTopDocs(TopDocs topDocs) throws IOException {
|
||||
protected List<ClassificationResult<BytesRef>> buildListFromTopDocs(TopDocs topDocs) throws IOException {
|
||||
Map<BytesRef, Integer> classCounts = new HashMap<>();
|
||||
Map<BytesRef, Double> classBoosts = new HashMap<>(); // this is a boost based on class ranking positions in topDocs
|
||||
float maxScore = topDocs.getMaxScore();
|
||||
|
|
|
@ -85,8 +85,9 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
|||
* @param analyzer an {@link Analyzer} used to analyze unseen text
|
||||
* @param query a {@link Query} to eventually filter the docs used for training the classifier, or {@code null}
|
||||
* if all the indexed docs should be used
|
||||
* @param classFieldName the name of the field used as the output for the classifier
|
||||
* @param textFieldNames the name of the fields used as the inputs for the classifier
|
||||
* @param classFieldName the name of the field used as the output for the classifier NOTE: must not be havely analyzed
|
||||
* as the returned class will be a token indexed for this field
|
||||
* @param textFieldNames the name of the fields used as the inputs for the classifier, NO boosting supported per field
|
||||
*/
|
||||
public SimpleNaiveBayesClassifier(LeafReader leafReader, Analyzer analyzer, Query query, String classFieldName, String... textFieldNames) {
|
||||
this.leafReader = leafReader;
|
||||
|
@ -102,16 +103,16 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
|||
*/
|
||||
@Override
|
||||
public ClassificationResult<BytesRef> assignClass(String inputDocument) throws IOException {
|
||||
List<ClassificationResult<BytesRef>> doclist = assignClassNormalizedList(inputDocument);
|
||||
ClassificationResult<BytesRef> retval = null;
|
||||
List<ClassificationResult<BytesRef>> assignedClasses = assignClassNormalizedList(inputDocument);
|
||||
ClassificationResult<BytesRef> assignedClass = null;
|
||||
double maxscore = -Double.MAX_VALUE;
|
||||
for (ClassificationResult<BytesRef> element : doclist) {
|
||||
if (element.getScore() > maxscore) {
|
||||
retval = element;
|
||||
maxscore = element.getScore();
|
||||
for (ClassificationResult<BytesRef> c : assignedClasses) {
|
||||
if (c.getScore() > maxscore) {
|
||||
assignedClass = c;
|
||||
maxscore = c.getScore();
|
||||
}
|
||||
}
|
||||
return retval;
|
||||
return assignedClass;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -119,9 +120,9 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
|||
*/
|
||||
@Override
|
||||
public List<ClassificationResult<BytesRef>> getClasses(String text) throws IOException {
|
||||
List<ClassificationResult<BytesRef>> doclist = assignClassNormalizedList(text);
|
||||
Collections.sort(doclist);
|
||||
return doclist;
|
||||
List<ClassificationResult<BytesRef>> assignedClasses = assignClassNormalizedList(text);
|
||||
Collections.sort(assignedClasses);
|
||||
return assignedClasses;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -129,9 +130,9 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
|||
*/
|
||||
@Override
|
||||
public List<ClassificationResult<BytesRef>> getClasses(String text, int max) throws IOException {
|
||||
List<ClassificationResult<BytesRef>> doclist = assignClassNormalizedList(text);
|
||||
Collections.sort(doclist);
|
||||
return doclist.subList(0, max);
|
||||
List<ClassificationResult<BytesRef>> assignedClasses = assignClassNormalizedList(text);
|
||||
Collections.sort(assignedClasses);
|
||||
return assignedClasses.subList(0, max);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -141,46 +142,26 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
|||
* @throws IOException if assigning probabilities fails
|
||||
*/
|
||||
protected List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException {
|
||||
List<ClassificationResult<BytesRef>> dataList = new ArrayList<>();
|
||||
List<ClassificationResult<BytesRef>> assignedClasses = new ArrayList<>();
|
||||
|
||||
Terms terms = MultiFields.getTerms(leafReader, classFieldName);
|
||||
TermsEnum termsEnum = terms.iterator();
|
||||
Terms classes = MultiFields.getTerms(leafReader, classFieldName);
|
||||
TermsEnum classesEnum = classes.iterator();
|
||||
BytesRef next;
|
||||
String[] tokenizedDoc = tokenizeDoc(inputDocument);
|
||||
String[] tokenizedText = tokenize(inputDocument);
|
||||
int docsWithClassSize = countDocsWithClass();
|
||||
while ((next = termsEnum.next()) != null) {
|
||||
while ((next = classesEnum.next()) != null) {
|
||||
if (next.length > 0) {
|
||||
// We are passing the term to IndexSearcher so we need to make sure it will not change over time
|
||||
next = BytesRef.deepCopyOf(next);
|
||||
double clVal = calculateLogPrior(next, docsWithClassSize) + calculateLogLikelihood(tokenizedDoc, next, docsWithClassSize);
|
||||
dataList.add(new ClassificationResult<>(next, clVal));
|
||||
double clVal = calculateLogPrior(next, docsWithClassSize) + calculateLogLikelihood(tokenizedText, next, docsWithClassSize);
|
||||
assignedClasses.add(new ClassificationResult<>(next, clVal));
|
||||
}
|
||||
}
|
||||
|
||||
// normalization; the values transforms to a 0-1 range
|
||||
ArrayList<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
|
||||
if (!dataList.isEmpty()) {
|
||||
Collections.sort(dataList);
|
||||
// this is a negative number closest to 0 = a
|
||||
double smax = dataList.get(0).getScore();
|
||||
ArrayList<ClassificationResult<BytesRef>> assignedClassesNorm = normClassificationResults(assignedClasses);
|
||||
|
||||
double sumLog = 0;
|
||||
// log(sum(exp(x_n-a)))
|
||||
for (ClassificationResult<BytesRef> cr : dataList) {
|
||||
// getScore-smax <=0 (both negative, smax is the smallest abs()
|
||||
sumLog += Math.exp(cr.getScore() - smax);
|
||||
}
|
||||
// loga=a+log(sum(exp(x_n-a))) = log(sum(exp(x_n)))
|
||||
double loga = smax;
|
||||
loga += Math.log(sumLog);
|
||||
|
||||
// 1/sum*x = exp(log(x))*1/sum = exp(log(x)-log(sum))
|
||||
for (ClassificationResult<BytesRef> cr : dataList) {
|
||||
returnList.add(new ClassificationResult<>(cr.getAssignedClass(), Math.exp(cr.getScore() - loga)));
|
||||
}
|
||||
}
|
||||
|
||||
return returnList;
|
||||
return assignedClassesNorm;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -192,15 +173,15 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
|||
protected int countDocsWithClass() throws IOException {
|
||||
int docCount = MultiFields.getTerms(this.leafReader, this.classFieldName).getDocCount();
|
||||
if (docCount == -1) { // in case codec doesn't support getDocCount
|
||||
TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
|
||||
TotalHitCountCollector classQueryCountCollector = new TotalHitCountCollector();
|
||||
BooleanQuery.Builder q = new BooleanQuery.Builder();
|
||||
q.add(new BooleanClause(new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))), BooleanClause.Occur.MUST));
|
||||
if (query != null) {
|
||||
q.add(query, BooleanClause.Occur.MUST);
|
||||
}
|
||||
indexSearcher.search(q.build(),
|
||||
totalHitCountCollector);
|
||||
docCount = totalHitCountCollector.getTotalHits();
|
||||
classQueryCountCollector);
|
||||
docCount = classQueryCountCollector.getTotalHits();
|
||||
}
|
||||
return docCount;
|
||||
}
|
||||
|
@ -208,14 +189,14 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
|||
/**
|
||||
* tokenize a <code>String</code> on this classifier's text fields and analyzer
|
||||
*
|
||||
* @param doc the <code>String</code> representing an input text (to be classified)
|
||||
* @param text the <code>String</code> representing an input text (to be classified)
|
||||
* @return a <code>String</code> array of the resulting tokens
|
||||
* @throws IOException if tokenization fails
|
||||
*/
|
||||
protected String[] tokenizeDoc(String doc) throws IOException {
|
||||
protected String[] tokenize(String text) throws IOException {
|
||||
Collection<String> result = new LinkedList<>();
|
||||
for (String textFieldName : textFieldNames) {
|
||||
try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, doc)) {
|
||||
try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, text)) {
|
||||
CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class);
|
||||
tokenStream.reset();
|
||||
while (tokenStream.incrementToken()) {
|
||||
|
@ -227,18 +208,18 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
|||
return result.toArray(new String[result.size()]);
|
||||
}
|
||||
|
||||
private double calculateLogLikelihood(String[] tokenizedDoc, BytesRef c, int docsWithClassSize) throws IOException {
|
||||
private double calculateLogLikelihood(String[] tokenizedText, BytesRef c, int docsWithClass) throws IOException {
|
||||
// for each word
|
||||
double result = 0d;
|
||||
for (String word : tokenizedDoc) {
|
||||
for (String word : tokenizedText) {
|
||||
// search with text:word AND class:c
|
||||
int hits = getWordFreqForClass(word, c);
|
||||
int hits = getWordFreqForClass(word,c);
|
||||
|
||||
// num : count the no of times the word appears in documents of class c (+1)
|
||||
double num = hits + 1; // +1 is added because of add 1 smoothing
|
||||
|
||||
// den : for the whole dictionary, count the no of times a word appears in documents of class c (+|V|)
|
||||
double den = getTextTermFreqForClass(c) + docsWithClassSize;
|
||||
double den = getTextTermFreqForClass(c) + docsWithClass;
|
||||
|
||||
// P(w|c) = num/den
|
||||
double wordProbability = num / den;
|
||||
|
@ -249,6 +230,12 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
|||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the average number of unique terms times the number of docs belonging to the input class
|
||||
* @param c the class
|
||||
* @return the average number of unique terms
|
||||
* @throws IOException if a low level I/O problem happens
|
||||
*/
|
||||
private double getTextTermFreqForClass(BytesRef c) throws IOException {
|
||||
double avgNumberOfUniqueTerms = 0;
|
||||
for (String textFieldName : textFieldNames) {
|
||||
|
@ -260,6 +247,14 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
|||
return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the number of documents of the input class ( from the whole index or from a subset)
|
||||
* that contains the word ( in a specific field or in all the fields if no one selected)
|
||||
* @param word the token produced by the analyzer
|
||||
* @param c the class
|
||||
* @return the number of documents of the input class
|
||||
* @throws IOException if a low level I/O problem happens
|
||||
*/
|
||||
private int getWordFreqForClass(String word, BytesRef c) throws IOException {
|
||||
BooleanQuery.Builder booleanQuery = new BooleanQuery.Builder();
|
||||
BooleanQuery.Builder subQuery = new BooleanQuery.Builder();
|
||||
|
@ -283,4 +278,36 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
|||
private int docCount(BytesRef countedClass) throws IOException {
|
||||
return leafReader.docFreq(new Term(classFieldName, countedClass));
|
||||
}
|
||||
|
||||
/**
|
||||
* Normalize the classification results based on the max score available
|
||||
* @param assignedClasses the list of assigned classes
|
||||
* @return the normalized results
|
||||
*/
|
||||
protected ArrayList<ClassificationResult<BytesRef>> normClassificationResults(List<ClassificationResult<BytesRef>> assignedClasses) {
|
||||
// normalization; the values transforms to a 0-1 range
|
||||
ArrayList<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
|
||||
if (!assignedClasses.isEmpty()) {
|
||||
Collections.sort(assignedClasses);
|
||||
// this is a negative number closest to 0 = a
|
||||
double smax = assignedClasses.get(0).getScore();
|
||||
|
||||
double sumLog = 0;
|
||||
// log(sum(exp(x_n-a)))
|
||||
for (ClassificationResult<BytesRef> cr : assignedClasses) {
|
||||
// getScore-smax <=0 (both negative, smax is the smallest abs()
|
||||
sumLog += Math.exp(cr.getScore() - smax);
|
||||
}
|
||||
// loga=a+log(sum(exp(x_n-a))) = log(sum(exp(x_n)))
|
||||
double loga = smax;
|
||||
loga += Math.log(sumLog);
|
||||
|
||||
// 1/sum*x = exp(log(x))*1/sum = exp(log(x)-log(sum))
|
||||
for (ClassificationResult<BytesRef> cr : assignedClasses) {
|
||||
double scoreDiff = cr.getScore() - loga;
|
||||
returnList.add(new ClassificationResult<>(cr.getAssignedClass(), Math.exp(scoreDiff)));
|
||||
}
|
||||
}
|
||||
return returnList;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
package org.apache.lucene.classification.document;
|
||||
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
import org.apache.lucene.classification.ClassificationResult;
|
||||
import org.apache.lucene.document.Document;
|
||||
|
||||
/**
|
||||
* A classifier, see <code>http://en.wikipedia.org/wiki/Classifier_(mathematics)</code>, which assign classes of type
|
||||
* <code>T</code> to a {@link org.apache.lucene.document.Document}s
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public interface DocumentClassifier<T> {
|
||||
/**
|
||||
* Assign a class (with score) to the given {@link org.apache.lucene.document.Document}
|
||||
*
|
||||
* @param document a {@link org.apache.lucene.document.Document} to be classified. Fields are considered features for the classification.
|
||||
* @return a {@link org.apache.lucene.classification.ClassificationResult} holding assigned class of type <code>T</code> and score
|
||||
* @throws java.io.IOException If there is a low-level I/O error.
|
||||
*/
|
||||
ClassificationResult<T> assignClass(Document document) throws IOException;
|
||||
|
||||
/**
|
||||
* Get all the classes (sorted by score, descending) assigned to the given {@link org.apache.lucene.document.Document}.
|
||||
*
|
||||
* @param document a {@link org.apache.lucene.document.Document} to be classified. Fields are considered features for the classification.
|
||||
* @return the whole list of {@link org.apache.lucene.classification.ClassificationResult}, the classes and scores. Returns <code>null</code> if the classifier can't make lists.
|
||||
* @throws java.io.IOException If there is a low-level I/O error.
|
||||
*/
|
||||
List<ClassificationResult<T>> getClasses(Document document) throws IOException;
|
||||
|
||||
/**
|
||||
* Get the first <code>max</code> classes (sorted by score, descending) assigned to the given text String.
|
||||
*
|
||||
* @param document a {@link org.apache.lucene.document.Document} to be classified. Fields are considered features for the classification.
|
||||
* @param max the number of return list elements
|
||||
* @return the whole list of {@link org.apache.lucene.classification.ClassificationResult}, the classes and scores. Cut for "max" number of elements. Returns <code>null</code> if the classifier can't make lists.
|
||||
* @throws java.io.IOException If there is a low-level I/O error.
|
||||
*/
|
||||
List<ClassificationResult<T>> getClasses(Document document, int max) throws IOException;
|
||||
|
||||
}
|
|
@ -0,0 +1,146 @@
|
|||
package org.apache.lucene.classification.document;
|
||||
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.StringReader;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import org.apache.lucene.analysis.Analyzer;
|
||||
import org.apache.lucene.classification.ClassificationResult;
|
||||
import org.apache.lucene.classification.KNearestNeighborClassifier;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.search.BooleanClause;
|
||||
import org.apache.lucene.search.BooleanQuery;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.search.WildcardQuery;
|
||||
import org.apache.lucene.search.similarities.Similarity;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
|
||||
/**
|
||||
* A k-Nearest Neighbor Document classifier (see <code>http://en.wikipedia.org/wiki/K-nearest_neighbors</code>) based
|
||||
* on {@link org.apache.lucene.queries.mlt.MoreLikeThis} .
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public class KNearestNeighborDocumentClassifier extends KNearestNeighborClassifier implements DocumentClassifier<BytesRef> {
|
||||
protected Map<String, Analyzer> field2analyzer;
|
||||
|
||||
/**
|
||||
* Creates a {@link KNearestNeighborClassifier}.
|
||||
*
|
||||
* @param leafReader the reader on the index to be used for classification
|
||||
* @param similarity the {@link Similarity} to be used by the underlying {@link IndexSearcher} or {@code null}
|
||||
* (defaults to {@link org.apache.lucene.search.similarities.ClassicSimilarity})
|
||||
* @param query a {@link org.apache.lucene.search.Query} to eventually filter the docs used for training the classifier, or {@code null}
|
||||
* if all the indexed docs should be used
|
||||
* @param k the no. of docs to select in the MLT results to find the nearest neighbor
|
||||
* @param minDocsFreq {@link org.apache.lucene.queries.mlt.MoreLikeThis#minDocFreq} parameter
|
||||
* @param minTermFreq {@link org.apache.lucene.queries.mlt.MoreLikeThis#minTermFreq} parameter
|
||||
* @param classFieldName the name of the field used as the output for the classifier
|
||||
* @param field2analyzer map with key a field name and the related {org.apache.lucene.analysis.Analyzer}
|
||||
* @param textFieldNames the name of the fields used as the inputs for the classifier, they can contain boosting indication e.g. title^10
|
||||
*/
|
||||
public KNearestNeighborDocumentClassifier(LeafReader leafReader, Similarity similarity, Query query, int k, int minDocsFreq,
|
||||
int minTermFreq, String classFieldName, Map<String, Analyzer> field2analyzer, String... textFieldNames) {
|
||||
super(leafReader,similarity, null, query, k, minDocsFreq, minTermFreq, classFieldName, textFieldNames);
|
||||
this.field2analyzer = field2analyzer;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public ClassificationResult<BytesRef> assignClass(Document document) throws IOException {
|
||||
TopDocs knnResults = knnSearch(document);
|
||||
List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults);
|
||||
ClassificationResult<BytesRef> assignedClass = null;
|
||||
double maxscore = -Double.MAX_VALUE;
|
||||
for (ClassificationResult<BytesRef> cl : assignedClasses) {
|
||||
if (cl.getScore() > maxscore) {
|
||||
assignedClass = cl;
|
||||
maxscore = cl.getScore();
|
||||
}
|
||||
}
|
||||
return assignedClass;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public List<ClassificationResult<BytesRef>> getClasses(Document document) throws IOException {
|
||||
TopDocs knnResults = knnSearch(document);
|
||||
List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults);
|
||||
Collections.sort(assignedClasses);
|
||||
return assignedClasses;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public List<ClassificationResult<BytesRef>> getClasses(Document document, int max) throws IOException {
|
||||
TopDocs knnResults = knnSearch(document);
|
||||
List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults);
|
||||
Collections.sort(assignedClasses);
|
||||
return assignedClasses.subList(0, max);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the top k results from a More Like This query based on the input document
|
||||
*
|
||||
* @param document the document to use for More Like This search
|
||||
* @return the top results for the MLT query
|
||||
* @throws IOException If there is a low-level I/O error
|
||||
*/
|
||||
private TopDocs knnSearch(Document document) throws IOException {
|
||||
BooleanQuery.Builder mltQuery = new BooleanQuery.Builder();
|
||||
|
||||
for (String fieldName : textFieldNames) {
|
||||
String boost = null;
|
||||
if (fieldName.contains("^")) {
|
||||
String[] field2boost = fieldName.split("\\^");
|
||||
fieldName = field2boost[0];
|
||||
boost = field2boost[1];
|
||||
}
|
||||
String[] fieldValues = document.getValues(fieldName);
|
||||
if (boost != null) {
|
||||
mlt.setBoost(true);
|
||||
mlt.setBoostFactor(Float.parseFloat(boost));
|
||||
}
|
||||
mlt.setAnalyzer(field2analyzer.get(fieldName));
|
||||
for (String fieldContent : fieldValues) {
|
||||
mltQuery.add(new BooleanClause(mlt.like(fieldName, new StringReader(fieldContent)), BooleanClause.Occur.SHOULD));
|
||||
}
|
||||
mlt.setBoost(false);
|
||||
}
|
||||
Query classFieldQuery = new WildcardQuery(new Term(classFieldName, "*"));
|
||||
mltQuery.add(new BooleanClause(classFieldQuery, BooleanClause.Occur.MUST));
|
||||
if (query != null) {
|
||||
mltQuery.add(query, BooleanClause.Occur.MUST);
|
||||
}
|
||||
return indexSearcher.search(mltQuery.build(), k);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,289 @@
|
|||
package org.apache.lucene.classification.document;
|
||||
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import org.apache.lucene.analysis.Analyzer;
|
||||
import org.apache.lucene.analysis.TokenStream;
|
||||
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
|
||||
import org.apache.lucene.classification.ClassificationResult;
|
||||
import org.apache.lucene.classification.SimpleNaiveBayesClassifier;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.MultiFields;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.index.Terms;
|
||||
import org.apache.lucene.index.TermsEnum;
|
||||
import org.apache.lucene.search.BooleanClause;
|
||||
import org.apache.lucene.search.BooleanQuery;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.search.TermQuery;
|
||||
import org.apache.lucene.search.TotalHitCountCollector;
|
||||
import org.apache.lucene.search.WildcardQuery;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
|
||||
/**
|
||||
* A simplistic Lucene based NaiveBayes classifier, see {@code http://en.wikipedia.org/wiki/Naive_Bayes_classifier}
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifier implements DocumentClassifier<BytesRef> {
|
||||
/**
|
||||
* {@link org.apache.lucene.analysis.Analyzer} to be used for tokenizing document fields
|
||||
*/
|
||||
protected Map<String, Analyzer> field2analyzer;
|
||||
|
||||
/**
|
||||
* Creates a new NaiveBayes classifier.
|
||||
*
|
||||
* @param leafReader the reader on the index to be used for classification
|
||||
* @param query a {@link org.apache.lucene.search.Query} to eventually filter the docs used for training the classifier, or {@code null}
|
||||
* if all the indexed docs should be used
|
||||
* @param classFieldName the name of the field used as the output for the classifier NOTE: must not be havely analyzed
|
||||
* as the returned class will be a token indexed for this field
|
||||
* @param textFieldNames the name of the fields used as the inputs for the classifier, they can contain boosting indication e.g. title^10
|
||||
*/
|
||||
public SimpleNaiveBayesDocumentClassifier(LeafReader leafReader, Query query, String classFieldName, Map<String, Analyzer> field2analyzer, String... textFieldNames) {
|
||||
super(leafReader, null, query, classFieldName, textFieldNames);
|
||||
this.field2analyzer = field2analyzer;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public ClassificationResult<BytesRef> assignClass(Document document) throws IOException {
|
||||
List<ClassificationResult<BytesRef>> assignedClasses = assignNormClasses(document);
|
||||
ClassificationResult<BytesRef> assignedClass = null;
|
||||
double maxscore = -Double.MAX_VALUE;
|
||||
for (ClassificationResult<BytesRef> c : assignedClasses) {
|
||||
if (c.getScore() > maxscore) {
|
||||
assignedClass = c;
|
||||
maxscore = c.getScore();
|
||||
}
|
||||
}
|
||||
return assignedClass;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public List<ClassificationResult<BytesRef>> getClasses(Document document) throws IOException {
|
||||
List<ClassificationResult<BytesRef>> assignedClasses = assignNormClasses(document);
|
||||
Collections.sort(assignedClasses);
|
||||
return assignedClasses;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public List<ClassificationResult<BytesRef>> getClasses(Document document, int max) throws IOException {
|
||||
List<ClassificationResult<BytesRef>> assignedClasses = assignNormClasses(document);
|
||||
Collections.sort(assignedClasses);
|
||||
return assignedClasses.subList(0, max);
|
||||
}
|
||||
|
||||
private List<ClassificationResult<BytesRef>> assignNormClasses(Document inputDocument) throws IOException {
|
||||
List<ClassificationResult<BytesRef>> assignedClasses = new ArrayList<>();
|
||||
Map<String, List<String[]>> fieldName2tokensArray = new LinkedHashMap<>();
|
||||
Map<String, Float> fieldName2boost = new LinkedHashMap<>();
|
||||
Terms classes = MultiFields.getTerms(leafReader, classFieldName);
|
||||
TermsEnum classesEnum = classes.iterator();
|
||||
BytesRef c;
|
||||
|
||||
analyzeSeedDocument(inputDocument, fieldName2tokensArray, fieldName2boost);
|
||||
|
||||
int docsWithClassSize = countDocsWithClass();
|
||||
while ((c = classesEnum.next()) != null) {
|
||||
double classScore = 0;
|
||||
for (String fieldName : textFieldNames) {
|
||||
List<String[]> tokensArrays = fieldName2tokensArray.get(fieldName);
|
||||
double fieldScore = 0;
|
||||
for (String[] fieldTokensArray : tokensArrays) {
|
||||
fieldScore += calculateLogPrior(c, docsWithClassSize) + calculateLogLikelihood(fieldTokensArray, fieldName, c, docsWithClassSize) * fieldName2boost.get(fieldName);
|
||||
}
|
||||
classScore += fieldScore;
|
||||
}
|
||||
assignedClasses.add(new ClassificationResult<>(BytesRef.deepCopyOf(c), classScore));
|
||||
}
|
||||
ArrayList<ClassificationResult<BytesRef>> assignedClassesNorm = normClassificationResults(assignedClasses);
|
||||
return assignedClassesNorm;
|
||||
}
|
||||
|
||||
/**
|
||||
* This methods performs the analysis for the seed document and extract the boosts if present.
|
||||
* This is done only one time for the Seed Document.
|
||||
*
|
||||
* @param inputDocument the seed unseen document
|
||||
* @param fieldName2tokensArray a map that associated to a field name the list of token arrays for all its values
|
||||
* @param fieldName2boost a map that associates the boost to the field
|
||||
* @throws IOException If there is a low-level I/O error
|
||||
*/
|
||||
private void analyzeSeedDocument(Document inputDocument, Map<String, List<String[]>> fieldName2tokensArray, Map<String, Float> fieldName2boost) throws IOException {
|
||||
for (int i = 0; i < textFieldNames.length; i++) {
|
||||
String fieldName = textFieldNames[i];
|
||||
float boost = 1;
|
||||
List<String[]> tokenizedValues = new LinkedList<>();
|
||||
if (fieldName.contains("^")) {
|
||||
String[] field2boost = fieldName.split("\\^");
|
||||
fieldName = field2boost[0];
|
||||
boost = Float.parseFloat(field2boost[1]);
|
||||
}
|
||||
Field[] fieldValues = inputDocument.getFields(fieldName);
|
||||
for (Field fieldValue : fieldValues) {
|
||||
TokenStream fieldTokens = fieldValue.tokenStream(field2analyzer.get(fieldName), null);
|
||||
String[] fieldTokensArray = getTokenArray(fieldTokens);
|
||||
tokenizedValues.add(fieldTokensArray);
|
||||
}
|
||||
fieldName2tokensArray.put(fieldName, tokenizedValues);
|
||||
fieldName2boost.put(fieldName, boost);
|
||||
textFieldNames[i] = fieldName;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Counts the number of documents in the index having at least a value for the 'class' field
|
||||
*
|
||||
* @return the no. of documents having a value for the 'class' field
|
||||
* @throws java.io.IOException If accessing to term vectors or search fails
|
||||
*/
|
||||
protected int countDocsWithClass() throws IOException {
|
||||
int docCount = MultiFields.getTerms(this.leafReader, this.classFieldName).getDocCount();
|
||||
if (docCount == -1) { // in case codec doesn't support getDocCount
|
||||
TotalHitCountCollector classQueryCountCollector = new TotalHitCountCollector();
|
||||
BooleanQuery.Builder q = new BooleanQuery.Builder();
|
||||
q.add(new BooleanClause(new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))), BooleanClause.Occur.MUST));
|
||||
if (query != null) {
|
||||
q.add(query, BooleanClause.Occur.MUST);
|
||||
}
|
||||
indexSearcher.search(q.build(),
|
||||
classQueryCountCollector);
|
||||
docCount = classQueryCountCollector.getTotalHits();
|
||||
}
|
||||
return docCount;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a token array from the {@link org.apache.lucene.analysis.TokenStream} in input
|
||||
*
|
||||
* @param tokenizedText the tokenized content of a field
|
||||
* @return a {@code String} array of the resulting tokens
|
||||
* @throws java.io.IOException If tokenization fails because there is a low-level I/O error
|
||||
*/
|
||||
protected String[] getTokenArray(TokenStream tokenizedText) throws IOException {
|
||||
Collection<String> tokens = new LinkedList<>();
|
||||
CharTermAttribute charTermAttribute = tokenizedText.addAttribute(CharTermAttribute.class);
|
||||
tokenizedText.reset();
|
||||
while (tokenizedText.incrementToken()) {
|
||||
tokens.add(charTermAttribute.toString());
|
||||
}
|
||||
tokenizedText.end();
|
||||
tokenizedText.close();
|
||||
return tokens.toArray(new String[tokens.size()]);
|
||||
}
|
||||
|
||||
/**
|
||||
* @param tokenizedText the tokenized content of a field
|
||||
* @param fieldName the input field name
|
||||
* @param c the class to calculate the score of
|
||||
* @param docsWithClass the total number of docs that have a class
|
||||
* @return a normalized score for the class
|
||||
* @throws IOException If there is a low-level I/O error
|
||||
*/
|
||||
private double calculateLogLikelihood(String[] tokenizedText, String fieldName, BytesRef c, int docsWithClass) throws IOException {
|
||||
// for each word
|
||||
double result = 0d;
|
||||
for (String word : tokenizedText) {
|
||||
// search with text:word AND class:c
|
||||
int hits = getWordFreqForClass(word, fieldName, c);
|
||||
|
||||
// num : count the no of times the word appears in documents of class c (+1)
|
||||
double num = hits + 1; // +1 is added because of add 1 smoothing
|
||||
|
||||
// den : for the whole dictionary, count the no of times a word appears in documents of class c (+|V|)
|
||||
double den = getTextTermFreqForClass(c, fieldName) + docsWithClass;
|
||||
|
||||
// P(w|c) = num/den
|
||||
double wordProbability = num / den;
|
||||
result += Math.log(wordProbability);
|
||||
}
|
||||
|
||||
// log(P(d|c)) = log(P(w1|c))+...+log(P(wn|c))
|
||||
double normScore = result / (tokenizedText.length); // this is normalized because if not, long text fields will always be more important than short fields
|
||||
return normScore;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the average number of unique terms times the number of docs belonging to the input class
|
||||
*
|
||||
* @param c the class
|
||||
* @return the average number of unique terms
|
||||
* @throws java.io.IOException If there is a low-level I/O error
|
||||
*/
|
||||
private double getTextTermFreqForClass(BytesRef c, String fieldName) throws IOException {
|
||||
double avgNumberOfUniqueTerms;
|
||||
Terms terms = MultiFields.getTerms(leafReader, fieldName);
|
||||
long numPostings = terms.getSumDocFreq(); // number of term/doc pairs
|
||||
avgNumberOfUniqueTerms = numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc
|
||||
int docsWithC = leafReader.docFreq(new Term(classFieldName, c));
|
||||
return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the number of documents of the input class ( from the whole index or from a subset)
|
||||
* that contains the word ( in a specific field or in all the fields if no one selected)
|
||||
*
|
||||
* @param word the token produced by the analyzer
|
||||
* @param fieldName the field the word is coming from
|
||||
* @param c the class
|
||||
* @return number of documents of the input class
|
||||
* @throws java.io.IOException If there is a low-level I/O error
|
||||
*/
|
||||
private int getWordFreqForClass(String word, String fieldName, BytesRef c) throws IOException {
|
||||
BooleanQuery.Builder booleanQuery = new BooleanQuery.Builder();
|
||||
BooleanQuery.Builder subQuery = new BooleanQuery.Builder();
|
||||
subQuery.add(new BooleanClause(new TermQuery(new Term(fieldName, word)), BooleanClause.Occur.SHOULD));
|
||||
booleanQuery.add(new BooleanClause(subQuery.build(), BooleanClause.Occur.MUST));
|
||||
booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST));
|
||||
if (query != null) {
|
||||
booleanQuery.add(query, BooleanClause.Occur.MUST);
|
||||
}
|
||||
TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
|
||||
indexSearcher.search(booleanQuery.build(), totalHitCountCollector);
|
||||
return totalHitCountCollector.getTotalHits();
|
||||
}
|
||||
|
||||
private double calculateLogPrior(BytesRef currentClass, int docsWithClassSize) throws IOException {
|
||||
return Math.log((double) docCount(currentClass)) - Math.log(docsWithClassSize);
|
||||
}
|
||||
|
||||
private int docCount(BytesRef countedClass) throws IOException {
|
||||
return leafReader.docFreq(new Term(classFieldName, countedClass));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
/**
|
||||
* Uses already seen data (the indexed documents) to classify new documents.
|
||||
* <p>
|
||||
* Currently contains a (simplistic) Naive Bayes classifier and a k-Nearest
|
||||
* Neighbor classifier.
|
||||
*/
|
||||
package org.apache.lucene.classification.document;
|
|
@ -16,7 +16,7 @@
|
|||
*/
|
||||
|
||||
/**
|
||||
* Uses already seen data (the indexed documents) to classify new documents.
|
||||
* Uses already seen data (the indexed documents) to classify an input ( can be simple text or a structured document).
|
||||
* <p>
|
||||
* Currently contains a (simplistic) Naive Bayes classifier, a k-Nearest
|
||||
* Neighbor classifier and a Perceptron based classifier.
|
||||
|
|
|
@ -57,8 +57,8 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
|
|||
protected static final BytesRef TECHNOLOGY_RESULT = new BytesRef("technology");
|
||||
|
||||
protected RandomIndexWriter indexWriter;
|
||||
private Directory dir;
|
||||
private FieldType ft;
|
||||
protected Directory dir;
|
||||
protected FieldType ft;
|
||||
|
||||
protected String textFieldName;
|
||||
protected String categoryFieldName;
|
||||
|
|
|
@ -0,0 +1,259 @@
|
|||
package org.apache.lucene.classification.document;
|
||||
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import org.apache.lucene.analysis.Analyzer;
|
||||
import org.apache.lucene.analysis.en.EnglishAnalyzer;
|
||||
import org.apache.lucene.classification.ClassificationResult;
|
||||
import org.apache.lucene.classification.ClassificationTestBase;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.index.IndexWriterConfig;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.RandomIndexWriter;
|
||||
import org.apache.lucene.index.SlowCompositeReaderWrapper;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.junit.Before;
|
||||
|
||||
/**
|
||||
* Base class for testing {@link org.apache.lucene.classification.Classifier}s
|
||||
*/
|
||||
public abstract class DocumentClassificationTestBase<T> extends ClassificationTestBase {
|
||||
|
||||
protected static final BytesRef VIDEOGAME_RESULT = new BytesRef("videogames");
|
||||
protected static final BytesRef VIDEOGAME_ANALYZED_RESULT = new BytesRef("videogam");
|
||||
protected static final BytesRef BATMAN_RESULT = new BytesRef("batman");
|
||||
|
||||
protected String titleFieldName = "title";
|
||||
protected String authorFieldName = "author";
|
||||
|
||||
protected Analyzer analyzer;
|
||||
protected Map<String, Analyzer> field2analyzer;
|
||||
protected LeafReader leafReader;
|
||||
|
||||
@Before
|
||||
public void init() throws IOException {
|
||||
analyzer = new EnglishAnalyzer();
|
||||
field2analyzer = new LinkedHashMap<>();
|
||||
field2analyzer.put(textFieldName, analyzer);
|
||||
field2analyzer.put(titleFieldName, analyzer);
|
||||
field2analyzer.put(authorFieldName, analyzer);
|
||||
leafReader = populateDocumentClassificationIndex(analyzer);
|
||||
}
|
||||
|
||||
protected double checkCorrectDocumentClassification(DocumentClassifier<T> classifier, Document inputDoc, T expectedResult) throws Exception {
|
||||
ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
|
||||
assertNotNull(classificationResult.getAssignedClass());
|
||||
assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
|
||||
double score = classificationResult.getScore();
|
||||
assertTrue("score should be between 0 and 1, got:" + score, score <= 1 && score >= 0);
|
||||
return score;
|
||||
}
|
||||
|
||||
protected LeafReader populateDocumentClassificationIndex(Analyzer analyzer) throws IOException {
|
||||
indexWriter.close();
|
||||
indexWriter = new RandomIndexWriter(random(), dir, newIndexWriterConfig(analyzer).setOpenMode(IndexWriterConfig.OpenMode.CREATE));
|
||||
indexWriter.commit();
|
||||
String text;
|
||||
String title;
|
||||
String author;
|
||||
|
||||
Document doc = new Document();
|
||||
title = "Video games are an economic business";
|
||||
text = "Video games have become an art form and an industry. The video game industry is of increasing" +
|
||||
" commercial importance, with growth driven particularly by the emerging Asian markets and mobile games." +
|
||||
" As of 2015, video games generated sales of USD 74 billion annually worldwide, and were the third-largest" +
|
||||
" segment in the U.S. entertainment market, behind broadcast and cable TV.";
|
||||
author = "Ign";
|
||||
doc.add(new Field(textFieldName, text, ft));
|
||||
doc.add(new Field(titleFieldName, title, ft));
|
||||
doc.add(new Field(authorFieldName, author, ft));
|
||||
doc.add(new Field(categoryFieldName, "videogames", ft));
|
||||
doc.add(new Field(booleanFieldName, "false", ft));
|
||||
indexWriter.addDocument(doc);
|
||||
|
||||
doc = new Document();
|
||||
title = "Video games: the definition of fun on PC and consoles";
|
||||
text = "A video game is an electronic game that involves human interaction with a user interface to generate" +
|
||||
" visual feedback on a video device. The word video in video game traditionally referred to a raster display device," +
|
||||
"[1] but it now implies any type of display device that can produce two- or three-dimensional images." +
|
||||
" The electronic systems used to play video games are known as platforms; examples of these are personal" +
|
||||
" computers and video game consoles. These platforms range from large mainframe computers to small handheld devices." +
|
||||
" Specialized video games such as arcade games, while previously common, have gradually declined in use.";
|
||||
author = "Ign";
|
||||
doc.add(new Field(textFieldName, text, ft));
|
||||
doc.add(new Field(titleFieldName, title, ft));
|
||||
doc.add(new Field(authorFieldName, author, ft));
|
||||
doc.add(new Field(categoryFieldName, "videogames", ft));
|
||||
doc.add(new Field(booleanFieldName, "false", ft));
|
||||
indexWriter.addDocument(doc);
|
||||
|
||||
doc = new Document();
|
||||
title = "Video games: the history across PC, consoles and fun";
|
||||
text = "Early games used interactive electronic devices with various display formats. The earliest example is" +
|
||||
" from 1947—a device was filed for a patent on 25 January 1947, by Thomas T. Goldsmith Jr. and Estle Ray Mann," +
|
||||
" and issued on 14 December 1948, as U.S. Patent 2455992.[2]" +
|
||||
"Inspired by radar display tech, it consisted of an analog device that allowed a user to control a vector-drawn" +
|
||||
" dot on the screen to simulate a missile being fired at targets, which were drawings fixed to the screen.[3]";
|
||||
author = "Ign";
|
||||
doc.add(new Field(textFieldName, text, ft));
|
||||
doc.add(new Field(titleFieldName, title, ft));
|
||||
doc.add(new Field(authorFieldName, author, ft));
|
||||
doc.add(new Field(categoryFieldName, "videogames", ft));
|
||||
doc.add(new Field(booleanFieldName, "false", ft));
|
||||
indexWriter.addDocument(doc);
|
||||
|
||||
doc = new Document();
|
||||
title = "Video games: the history";
|
||||
text = "Early games used interactive electronic devices with various display formats. The earliest example is" +
|
||||
" from 1947—a device was filed for a patent on 25 January 1947, by Thomas T. Goldsmith Jr. and Estle Ray Mann," +
|
||||
" and issued on 14 December 1948, as U.S. Patent 2455992.[2]" +
|
||||
"Inspired by radar display tech, it consisted of an analog device that allowed a user to control a vector-drawn" +
|
||||
" dot on the screen to simulate a missile being fired at targets, which were drawings fixed to the screen.[3]";
|
||||
author = "Ign";
|
||||
doc.add(new Field(textFieldName, text, ft));
|
||||
doc.add(new Field(titleFieldName, title, ft));
|
||||
doc.add(new Field(authorFieldName, author, ft));
|
||||
doc.add(new Field(categoryFieldName, "videogames", ft));
|
||||
doc.add(new Field(booleanFieldName, "false", ft));
|
||||
indexWriter.addDocument(doc);
|
||||
|
||||
doc = new Document();
|
||||
title = "Batman: Arkham Knight PC Benchmarks, For What They're Worth";
|
||||
text = "Although I didn’t spend much time playing Batman: Arkham Origins, I remember the game rather well after" +
|
||||
" testing it on no less than 30 graphics cards and 20 CPUs. Arkham Origins appeared to take full advantage of" +
|
||||
" Unreal Engine 3, it ran smoothly on affordable GPUs, though it’s worth remembering that Origins was developed " +
|
||||
"for last-gen consoles.This week marked the arrival of Batman: Arkham Knight, the fourth entry in WB’s Batman:" +
|
||||
" Arkham series and a direct sequel to 2013’s Arkham Origins 2011’s Arkham City." +
|
||||
"Arkham Knight is also powered by Unreal Engine 3, but you can expect noticeably improved graphics, in part because" +
|
||||
" the PlayStation 4 and Xbox One have replaced the PS3 and 360 as the lowest common denominator.";
|
||||
author = "Rocksteady Studios";
|
||||
doc.add(new Field(textFieldName, text, ft));
|
||||
doc.add(new Field(titleFieldName, title, ft));
|
||||
doc.add(new Field(authorFieldName, author, ft));
|
||||
doc.add(new Field(categoryFieldName, "batman", ft));
|
||||
doc.add(new Field(booleanFieldName, "false", ft));
|
||||
indexWriter.addDocument(doc);
|
||||
|
||||
doc = new Document();
|
||||
title = "Face-Off: Batman: Arkham Knight, the Dark Knight returns!";
|
||||
text = "Despite the drama surrounding the PC release leading to its subsequent withdrawal, there's a sense of success" +
|
||||
" in the console space as PlayStation 4 owners, and indeed those on Xbox One, get a superb rendition of Batman:" +
|
||||
" Arkham Knight. It's fair to say Rocksteady sized up each console's strengths well ahead of producing its first" +
|
||||
" current-gen title, and it's paid off in one of the best Batman games we've seen in years.";
|
||||
author = "Rocksteady Studios";
|
||||
doc.add(new Field(textFieldName, text, ft));
|
||||
doc.add(new Field(titleFieldName, title, ft));
|
||||
doc.add(new Field(authorFieldName, author, ft));
|
||||
doc.add(new Field(categoryFieldName, "batman", ft));
|
||||
doc.add(new Field(booleanFieldName, "false", ft));
|
||||
indexWriter.addDocument(doc);
|
||||
|
||||
doc = new Document();
|
||||
title = "Batman: Arkham Knight Having More Trouble, But This Time not in Gotham";
|
||||
text = "As news began to break about the numerous issues affecting the PC version of Batman: Arkham Knight, players" +
|
||||
" of the console version breathed a sigh of relief and got back to playing the game. Now players of the PlayStation" +
|
||||
" 4 version are having problems of their own, albeit much less severe ones." +
|
||||
"This time Batman will have a difficult time in Gotham.";
|
||||
author = "Rocksteady Studios";
|
||||
doc.add(new Field(textFieldName, text, ft));
|
||||
doc.add(new Field(titleFieldName, title, ft));
|
||||
doc.add(new Field(authorFieldName, author, ft));
|
||||
doc.add(new Field(categoryFieldName, "batman", ft));
|
||||
doc.add(new Field(booleanFieldName, "false", ft));
|
||||
indexWriter.addDocument(doc);
|
||||
|
||||
doc = new Document();
|
||||
title = "Batman: Arkham Knight the new legend of Gotham";
|
||||
text = "As news began to break about the numerous issues affecting the PC version of the game, players" +
|
||||
" of the console version breathed a sigh of relief and got back to play. Now players of the PlayStation" +
|
||||
" 4 version are having problems of their own, albeit much less severe ones.";
|
||||
author = "Rocksteady Studios";
|
||||
doc.add(new Field(textFieldName, text, ft));
|
||||
doc.add(new Field(titleFieldName, title, ft));
|
||||
doc.add(new Field(authorFieldName, author, ft));
|
||||
doc.add(new Field(categoryFieldName, "batman", ft));
|
||||
doc.add(new Field(booleanFieldName, "false", ft));
|
||||
indexWriter.addDocument(doc);
|
||||
|
||||
|
||||
doc = new Document();
|
||||
text = "unlabeled doc";
|
||||
doc.add(new Field(textFieldName, text, ft));
|
||||
indexWriter.addDocument(doc);
|
||||
|
||||
indexWriter.commit();
|
||||
return SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
|
||||
}
|
||||
|
||||
protected Document getVideoGameDocument() {
|
||||
Document doc = new Document();
|
||||
String title = "The new generation of PC and Console Video games";
|
||||
String text = "Recently a lot of games have been released for the latest generations of consoles and personal computers." +
|
||||
"One of them is Batman: Arkham Knight released recently on PS4, X-box and personal computer." +
|
||||
"Another important video game that will be released in November is Assassin's Creed, a classic series that sees its new installement on Halloween." +
|
||||
"Recently a lot of problems affected the Assassin's creed series but this time it should ran smoothly on affordable GPUs." +
|
||||
"Players are waiting for the versions of their favourite video games and so do we.";
|
||||
String author = "Ign";
|
||||
doc.add(new Field(textFieldName, text, ft));
|
||||
doc.add(new Field(titleFieldName, title, ft));
|
||||
doc.add(new Field(authorFieldName, author, ft));
|
||||
doc.add(new Field(booleanFieldName, "false", ft));
|
||||
return doc;
|
||||
}
|
||||
|
||||
protected Document getBatmanDocument() {
|
||||
Document doc = new Document();
|
||||
String title = "Batman: Arkham Knight new adventures for the super hero across Gotham, the Dark Knight has returned!";
|
||||
String title2 = "I am a second title !";
|
||||
String text = "This game is the electronic version of the famous super hero adventures.It involves the interaction with the open world" +
|
||||
" of the city of Gotham. Finally the player will be able to have fun on its personal device." +
|
||||
" The three-dimensional images of the game are stunning, because it uses the Unreal Engine 3." +
|
||||
" The systems available are PS4, X-Box and personal computer." +
|
||||
" Will the simulate missile that is going to be fired, success ?\" +\n" +
|
||||
" Will this video game make the history" +
|
||||
" Help you favourite super hero to defeat all his enemies. The Dark Knight has returned !";
|
||||
String author = "Rocksteady Studios";
|
||||
doc.add(new Field(textFieldName, text, ft));
|
||||
doc.add(new Field(titleFieldName, title, ft));
|
||||
doc.add(new Field(titleFieldName, title2, ft));
|
||||
doc.add(new Field(authorFieldName, author, ft));
|
||||
doc.add(new Field(booleanFieldName, "false", ft));
|
||||
return doc;
|
||||
}
|
||||
|
||||
protected Document getBatmanAmbiguosDocument() {
|
||||
Document doc = new Document();
|
||||
String title = "Batman: Arkham Knight new adventures for the super hero across Gotham, the Dark Knight has returned! Batman will win !";
|
||||
String text = "Early games used interactive electronic devices with various display formats. The earliest example is" +
|
||||
" from 1947—a device was filed for a patent on 25 January 1947, by Thomas T. Goldsmith Jr. and Estle Ray Mann," +
|
||||
" and issued on 14 December 1948, as U.S. Patent 2455992.[2]" +
|
||||
"Inspired by radar display tech, it consisted of an analog device that allowed a user to control a vector-drawn" +
|
||||
" dot on the screen to simulate a missile being fired at targets, which were drawings fixed to the screen.[3]";
|
||||
String author = "Ign";
|
||||
doc.add(new Field(textFieldName, text, ft));
|
||||
doc.add(new Field(titleFieldName, title, ft));
|
||||
doc.add(new Field(authorFieldName, author, ft));
|
||||
doc.add(new Field(booleanFieldName, "false", ft));
|
||||
return doc;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,96 @@
|
|||
package org.apache.lucene.classification.document;
|
||||
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.search.TermQuery;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.junit.Test;
|
||||
|
||||
/**
|
||||
* Tests for {@link org.apache.lucene.classification.KNearestNeighborClassifier}
|
||||
*/
|
||||
public class KNearestNeighborDocumentClassifierTest extends DocumentClassificationTestBase<BytesRef> {
|
||||
|
||||
@Test
|
||||
public void testBasicDocumentClassification() throws Exception {
|
||||
try {
|
||||
Document videoGameDocument = getVideoGameDocument();
|
||||
Document batmanDocument = getBatmanDocument();
|
||||
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), videoGameDocument, VIDEOGAME_RESULT);
|
||||
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), batmanDocument, BATMAN_RESULT);
|
||||
// considering only the text we have wrong classification because the text was ambiguos on purpose
|
||||
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), videoGameDocument, BATMAN_RESULT);
|
||||
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), batmanDocument, VIDEOGAME_RESULT);
|
||||
|
||||
} finally {
|
||||
if (leafReader != null) {
|
||||
leafReader.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBasicDocumentClassificationScore() throws Exception {
|
||||
try {
|
||||
Document videoGameDocument = getVideoGameDocument();
|
||||
Document batmanDocument = getBatmanDocument();
|
||||
double score1 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), videoGameDocument, VIDEOGAME_RESULT);
|
||||
assertEquals(1.0,score1,0);
|
||||
double score2 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), batmanDocument, BATMAN_RESULT);
|
||||
assertEquals(1.0,score2,0);
|
||||
// considering only the text we have wrong classification because the text was ambiguos on purpose
|
||||
double score3 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), videoGameDocument, BATMAN_RESULT);
|
||||
assertEquals(1.0,score3,0);
|
||||
double score4 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), batmanDocument, VIDEOGAME_RESULT);
|
||||
assertEquals(1.0,score4,0);
|
||||
} finally {
|
||||
if (leafReader != null) {
|
||||
leafReader.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBoostedDocumentClassification() throws Exception {
|
||||
try {
|
||||
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName + "^100", authorFieldName}), getBatmanAmbiguosDocument(), BATMAN_RESULT);
|
||||
// considering without boost wrong classification will appear
|
||||
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanAmbiguosDocument(), VIDEOGAME_RESULT);
|
||||
} finally {
|
||||
if (leafReader != null) {
|
||||
leafReader.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBasicDocumentClassificationWithQuery() throws Exception {
|
||||
try {
|
||||
TermQuery query = new TermQuery(new Term(authorFieldName, "ign"));
|
||||
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, query, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_RESULT);
|
||||
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null,query, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), VIDEOGAME_RESULT);
|
||||
} finally {
|
||||
if (leafReader != null) {
|
||||
leafReader.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,76 @@
|
|||
package org.apache.lucene.classification.document;
|
||||
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.junit.Test;
|
||||
|
||||
/**
|
||||
* Tests for {@link org.apache.lucene.classification.SimpleNaiveBayesClassifier}
|
||||
*/
|
||||
public class SimpleNaiveBayesDocumentClassifierTest extends DocumentClassificationTestBase<BytesRef> {
|
||||
|
||||
@Test
|
||||
public void testBasicDocumentClassification() throws Exception {
|
||||
try {
|
||||
checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_ANALYZED_RESULT);
|
||||
checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), BATMAN_RESULT);
|
||||
|
||||
checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getVideoGameDocument(), BATMAN_RESULT);
|
||||
checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getBatmanDocument(), VIDEOGAME_ANALYZED_RESULT);
|
||||
} finally {
|
||||
if (leafReader != null) {
|
||||
leafReader.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBasicDocumentClassificationScore() throws Exception {
|
||||
try {
|
||||
double score1 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_ANALYZED_RESULT);
|
||||
assertEquals(0.88,score1,0.01);
|
||||
double score2 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), BATMAN_RESULT);
|
||||
assertEquals(0.89,score2,0.01);
|
||||
//taking in consideration only the text
|
||||
double score3 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getVideoGameDocument(), BATMAN_RESULT);
|
||||
assertEquals(0.55,score3,0.01);
|
||||
double score4 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getBatmanDocument(), VIDEOGAME_ANALYZED_RESULT);
|
||||
assertEquals(0.52,score4,0.01);
|
||||
} finally {
|
||||
if (leafReader != null) {
|
||||
leafReader.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBoostedDocumentClassification() throws Exception {
|
||||
try {
|
||||
checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName+"^100", authorFieldName}), getBatmanAmbiguosDocument(), BATMAN_RESULT);
|
||||
// considering without boost wrong classification will appear
|
||||
checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanAmbiguosDocument(), VIDEOGAME_ANALYZED_RESULT);
|
||||
} finally {
|
||||
if (leafReader != null) {
|
||||
leafReader.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
Loading…
Reference in New Issue