diff --git a/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java index 1be235d4e91..f0d909f8d14 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java @@ -2,7 +2,6 @@ package org.apache.lucene.classification; import java.io.IOException; import java.util.ArrayList; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -81,38 +80,17 @@ public class CachingNaiveBayesClassifier extends SimpleNaiveBayesClassifier { protected List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException { - String[] tokenizedDoc = tokenizeDoc(inputDocument); + String[] tokenizedText = tokenize(inputDocument); - List<ClassificationResult<BytesRef>> dataList = calculateLogLikelihood(tokenizedDoc); + List<ClassificationResult<BytesRef>> assignedClasses = calculateLogLikelihood(tokenizedText); // normalization // The values are transformed to a 0-1 range - ArrayList<ClassificationResult<BytesRef>> returnList = new ArrayList<>(); - if (!dataList.isEmpty()) { - Collections.sort(dataList); - // this is a negative number closest to 0 = a - double smax = dataList.get(0).getScore(); - - double sumLog = 0; - // log(sum(exp(x_n-a))) - for (ClassificationResult<BytesRef> cr : dataList) { - // getScore-smax <= 0 (both negative; smax has the smallest abs()) - sumLog += Math.exp(cr.getScore() - smax); - } - // loga=a+log(sum(exp(x_n-a))) = log(sum(exp(x_n))) - double loga = smax; - loga += Math.log(sumLog); - - // 1/sum*x = exp(log(x))*1/sum = exp(log(x)-log(sum)) - for (ClassificationResult<BytesRef> cr : dataList) { - returnList.add(new ClassificationResult<>(cr.getAssignedClass(), Math.exp(cr.getScore() - loga))); - } - } - - return returnList; + ArrayList<ClassificationResult<BytesRef>> assignedClassesNorm = super.normClassificationResults(assignedClasses); + return assignedClassesNorm; } - private List<ClassificationResult<BytesRef>> calculateLogLikelihood(String[] tokenizedDoc) throws IOException { + private List<ClassificationResult<BytesRef>> calculateLogLikelihood(String[] tokenizedText) throws IOException { // initialize the return List ArrayList<ClassificationResult<BytesRef>> ret = new ArrayList<>(); for (BytesRef cclass : cclasses) { @@ -120,7 +98,7 @@ public class CachingNaiveBayesClassifier extends SimpleNaiveBayesClassifier { ret.add(cr); } // for each word - for (String word : tokenizedDoc) { + for (String word : tokenizedText) { // search with text:word for all class:c Map<BytesRef, Integer> hitsInClasses = getWordFreqForClassess(word); // for each class
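For reference, the block removed above is the standard log-sum-exp normalization that this patch factors into the shared normClassificationResults() helper (added later in SimpleNaiveBayesClassifier). A minimal standalone sketch of the same trick, with made-up log-space scores, assuming the scores are negative log values as in the classifier:

// Standalone sketch of the log-sum-exp normalization shared via
// normClassificationResults(): turns log-space class scores into
// probabilities in the 0-1 range without underflowing in Math.exp().
import java.util.Arrays;

public class LogSumExpDemo {
  public static void main(String[] args) {
    double[] logScores = {-120.3, -121.7, -125.0}; // hypothetical log P(c) + log P(d|c) values
    double max = Arrays.stream(logScores).max().getAsDouble();
    double sumExp = 0;
    for (double s : logScores) {
      sumExp += Math.exp(s - max); // s - max <= 0, so exp() stays in [0, 1]
    }
    double logSum = max + Math.log(sumExp); // = log(sum(exp(s_n)))
    for (double s : logScores) {
      System.out.printf("%.6f%n", Math.exp(s - logSum)); // normalized 0-1 score
    }
  }
}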
diff --git a/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java index d088787f1fd..6b49243bc84 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java @@ -48,12 +48,12 @@ import org.apache.lucene.util.BytesRef; */ public class KNearestNeighborClassifier implements Classifier<BytesRef> { - private final MoreLikeThis mlt; - private final String[] textFieldNames; - private final String classFieldName; - private final IndexSearcher indexSearcher; - private final int k; - private final Query query; + protected final MoreLikeThis mlt; + protected final String[] textFieldNames; + protected final String classFieldName; + protected final IndexSearcher indexSearcher; + protected final int k; + protected final Query query; /** * Creates a {@link KNearestNeighborClassifier}. @@ -159,7 +159,7 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> { } // ranking of classes must be taken into consideration - private List<ClassificationResult<BytesRef>> buildListFromTopDocs(TopDocs topDocs) throws IOException { + protected List<ClassificationResult<BytesRef>> buildListFromTopDocs(TopDocs topDocs) throws IOException { Map<BytesRef, Integer> classCounts = new HashMap<>(); Map<BytesRef, Double> classBoosts = new HashMap<>(); // this is a boost based on class ranking positions in topDocs float maxScore = topDocs.getMaxScore(); diff --git a/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java index b236b5ac4ec..85690882480 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java @@ -85,8 +85,9 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> { * @param analyzer an {@link Analyzer} used to analyze unseen text * @param query a {@link Query} to eventually filter the docs used for training the classifier, or {@code null} * if all the indexed docs should be used - * @param classFieldName the name of the field used as the output for the classifier - * @param textFieldNames the name of the fields used as the inputs for the classifier + * @param classFieldName the name of the field used as the output for the classifier; NOTE: it must not be heavily analyzed, + * as the returned class will be a token indexed for this field + * @param textFieldNames the names of the fields used as the inputs for the classifier; no per-field boosting is supported */ public SimpleNaiveBayesClassifier(LeafReader leafReader, Analyzer analyzer, Query query, String classFieldName, String...
textFieldNames) { this.leafReader = leafReader; @@ -102,16 +103,16 @@ public class SimpleNaiveBayesClassifier implements Classifier { */ @Override public ClassificationResult assignClass(String inputDocument) throws IOException { - List> doclist = assignClassNormalizedList(inputDocument); - ClassificationResult retval = null; + List> assignedClasses = assignClassNormalizedList(inputDocument); + ClassificationResult assignedClass = null; double maxscore = -Double.MAX_VALUE; - for (ClassificationResult element : doclist) { - if (element.getScore() > maxscore) { - retval = element; - maxscore = element.getScore(); + for (ClassificationResult c : assignedClasses) { + if (c.getScore() > maxscore) { + assignedClass = c; + maxscore = c.getScore(); } } - return retval; + return assignedClass; } /** @@ -119,9 +120,9 @@ public class SimpleNaiveBayesClassifier implements Classifier { */ @Override public List> getClasses(String text) throws IOException { - List> doclist = assignClassNormalizedList(text); - Collections.sort(doclist); - return doclist; + List> assignedClasses = assignClassNormalizedList(text); + Collections.sort(assignedClasses); + return assignedClasses; } /** @@ -129,9 +130,9 @@ public class SimpleNaiveBayesClassifier implements Classifier { */ @Override public List> getClasses(String text, int max) throws IOException { - List> doclist = assignClassNormalizedList(text); - Collections.sort(doclist); - return doclist.subList(0, max); + List> assignedClasses = assignClassNormalizedList(text); + Collections.sort(assignedClasses); + return assignedClasses.subList(0, max); } /** @@ -141,46 +142,26 @@ public class SimpleNaiveBayesClassifier implements Classifier { * @throws IOException if assigning probabilities fails */ protected List> assignClassNormalizedList(String inputDocument) throws IOException { - List> dataList = new ArrayList<>(); + List> assignedClasses = new ArrayList<>(); - Terms terms = MultiFields.getTerms(leafReader, classFieldName); - TermsEnum termsEnum = terms.iterator(); + Terms classes = MultiFields.getTerms(leafReader, classFieldName); + TermsEnum classesEnum = classes.iterator(); BytesRef next; - String[] tokenizedDoc = tokenizeDoc(inputDocument); + String[] tokenizedText = tokenize(inputDocument); int docsWithClassSize = countDocsWithClass(); - while ((next = termsEnum.next()) != null) { + while ((next = classesEnum.next()) != null) { if (next.length > 0) { // We are passing the term to IndexSearcher so we need to make sure it will not change over time next = BytesRef.deepCopyOf(next); - double clVal = calculateLogPrior(next, docsWithClassSize) + calculateLogLikelihood(tokenizedDoc, next, docsWithClassSize); - dataList.add(new ClassificationResult<>(next, clVal)); + double clVal = calculateLogPrior(next, docsWithClassSize) + calculateLogLikelihood(tokenizedText, next, docsWithClassSize); + assignedClasses.add(new ClassificationResult<>(next, clVal)); } } // normalization; the values transforms to a 0-1 range - ArrayList> returnList = new ArrayList<>(); - if (!dataList.isEmpty()) { - Collections.sort(dataList); - // this is a negative number closest to 0 = a - double smax = dataList.get(0).getScore(); + ArrayList> assignedClassesNorm = normClassificationResults(assignedClasses); - double sumLog = 0; - // log(sum(exp(x_n-a))) - for (ClassificationResult cr : dataList) { - // getScore-smax <=0 (both negative, smax is the smallest abs() - sumLog += Math.exp(cr.getScore() - smax); - } - // loga=a+log(sum(exp(x_n-a))) = log(sum(exp(x_n))) - double loga = smax; - loga += 
Math.log(sumLog); - - // 1/sum*x = exp(log(x))*1/sum = exp(log(x)-log(sum)) - for (ClassificationResult cr : dataList) { - returnList.add(new ClassificationResult<>(cr.getAssignedClass(), Math.exp(cr.getScore() - loga))); - } - } - - return returnList; + return assignedClassesNorm; } /** @@ -192,15 +173,15 @@ public class SimpleNaiveBayesClassifier implements Classifier { protected int countDocsWithClass() throws IOException { int docCount = MultiFields.getTerms(this.leafReader, this.classFieldName).getDocCount(); if (docCount == -1) { // in case codec doesn't support getDocCount - TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector(); + TotalHitCountCollector classQueryCountCollector = new TotalHitCountCollector(); BooleanQuery.Builder q = new BooleanQuery.Builder(); q.add(new BooleanClause(new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))), BooleanClause.Occur.MUST)); if (query != null) { q.add(query, BooleanClause.Occur.MUST); } indexSearcher.search(q.build(), - totalHitCountCollector); - docCount = totalHitCountCollector.getTotalHits(); + classQueryCountCollector); + docCount = classQueryCountCollector.getTotalHits(); } return docCount; } @@ -208,14 +189,14 @@ public class SimpleNaiveBayesClassifier implements Classifier { /** * tokenize a String on this classifier's text fields and analyzer * - * @param doc the String representing an input text (to be classified) + * @param text the String representing an input text (to be classified) * @return a String array of the resulting tokens * @throws IOException if tokenization fails */ - protected String[] tokenizeDoc(String doc) throws IOException { + protected String[] tokenize(String text) throws IOException { Collection result = new LinkedList<>(); for (String textFieldName : textFieldNames) { - try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, doc)) { + try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, text)) { CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class); tokenStream.reset(); while (tokenStream.incrementToken()) { @@ -227,18 +208,18 @@ public class SimpleNaiveBayesClassifier implements Classifier { return result.toArray(new String[result.size()]); } - private double calculateLogLikelihood(String[] tokenizedDoc, BytesRef c, int docsWithClassSize) throws IOException { + private double calculateLogLikelihood(String[] tokenizedText, BytesRef c, int docsWithClass) throws IOException { // for each word double result = 0d; - for (String word : tokenizedDoc) { + for (String word : tokenizedText) { // search with text:word AND class:c - int hits = getWordFreqForClass(word, c); + int hits = getWordFreqForClass(word,c); // num : count the no of times the word appears in documents of class c (+1) double num = hits + 1; // +1 is added because of add 1 smoothing // den : for the whole dictionary, count the no of times a word appears in documents of class c (+|V|) - double den = getTextTermFreqForClass(c) + docsWithClassSize; + double den = getTextTermFreqForClass(c) + docsWithClass; // P(w|c) = num/den double wordProbability = num / den; @@ -249,6 +230,12 @@ public class SimpleNaiveBayesClassifier implements Classifier { return result; } + /** + * Returns the average number of unique terms times the number of docs belonging to the input class + * @param c the class + * @return the average number of unique terms + * @throws IOException if a low level I/O problem happens + */ private double 
getTextTermFreqForClass(BytesRef c) throws IOException { double avgNumberOfUniqueTerms = 0; for (String textFieldName : textFieldNames) { @@ -260,6 +247,14 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> { return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c } + /** + * Returns the number of documents of the input class (from the whole index or from a subset) + * that contain the word (in a specific field, or in all the fields if none is selected) + * @param word the token produced by the analyzer + * @param c the class + * @return the number of documents of the input class + * @throws IOException if a low level I/O problem happens + */ private int getWordFreqForClass(String word, BytesRef c) throws IOException { BooleanQuery.Builder booleanQuery = new BooleanQuery.Builder(); BooleanQuery.Builder subQuery = new BooleanQuery.Builder(); @@ -283,4 +278,36 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> { private int docCount(BytesRef countedClass) throws IOException { return leafReader.docFreq(new Term(classFieldName, countedClass)); } + + /** + * Normalizes the classification results based on the max score available + * @param assignedClasses the list of assigned classes + * @return the normalized results + */ + protected ArrayList<ClassificationResult<BytesRef>> normClassificationResults(List<ClassificationResult<BytesRef>> assignedClasses) { + // normalization; the values are transformed to a 0-1 range + ArrayList<ClassificationResult<BytesRef>> returnList = new ArrayList<>(); + if (!assignedClasses.isEmpty()) { + Collections.sort(assignedClasses); + // this is a negative number closest to 0 = a + double smax = assignedClasses.get(0).getScore(); + + double sumLog = 0; + // log(sum(exp(x_n-a))) + for (ClassificationResult<BytesRef> cr : assignedClasses) { + // getScore-smax <= 0 (both negative; smax has the smallest abs()) + sumLog += Math.exp(cr.getScore() - smax); + } + // loga=a+log(sum(exp(x_n-a))) = log(sum(exp(x_n))) + double loga = smax; + loga += Math.log(sumLog); + + // 1/sum*x = exp(log(x))*1/sum = exp(log(x)-log(sum)) + for (ClassificationResult<BytesRef> cr : assignedClasses) { + double scoreDiff = cr.getScore() - loga; + returnList.add(new ClassificationResult<>(cr.getAssignedClass(), Math.exp(scoreDiff))); + } + } + return returnList; + } }
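To make the naive Bayes scoring above concrete: each class score is a log prior plus a sum of add-one (Laplace) smoothed log word likelihoods, exactly as calculateLogLikelihood and calculateLogPrior compute them. A toy sketch with made-up counts (the variable names mirror the patch but the numbers are hypothetical):

// Toy illustration of the smoothed scoring:
// score(c) = log(docCount(c)/docsWithClass) + sum_w log((hits(w,c)+1) / (termFreqForClass(c)+docsWithClass))
public class NaiveBayesScoreDemo {
  public static void main(String[] args) {
    int docCountForC = 40, docsWithClass = 100;     // hypothetical counts
    int[] wordHitsInClass = {7, 0, 3};              // hits for each query token in class c
    double termFreqForClass = 1200;                 // avg unique terms per doc * docs with c
    double score = Math.log((double) docCountForC) - Math.log(docsWithClass); // log prior
    for (int hits : wordHitsInClass) {
      double num = hits + 1;                        // +1: add-one smoothing, avoids log(0)
      double den = termFreqForClass + docsWithClass;
      score += Math.log(num / den);                 // accumulate log likelihood
    }
    System.out.println("log score for class c: " + score); // fed to normClassificationResults
  }
}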
diff --git a/lucene/classification/src/java/org/apache/lucene/classification/document/DocumentClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/document/DocumentClassifier.java new file mode 100644 index 00000000000..2b568795c09 --- /dev/null +++ b/lucene/classification/src/java/org/apache/lucene/classification/document/DocumentClassifier.java @@ -0,0 +1,61 @@ +package org.apache.lucene.classification.document; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.io.IOException; +import java.util.List; + +import org.apache.lucene.classification.ClassificationResult; +import org.apache.lucene.document.Document; + +/** + * A classifier, see http://en.wikipedia.org/wiki/Classifier_(mathematics), which assigns classes of type + * T to {@link org.apache.lucene.document.Document}s + * + * @lucene.experimental + */ +public interface DocumentClassifier<T> { + /** + * Assign a class (with score) to the given {@link org.apache.lucene.document.Document} + * + * @param document a {@link org.apache.lucene.document.Document} to be classified. Fields are considered features for the classification. + * @return a {@link org.apache.lucene.classification.ClassificationResult} holding the assigned class of type T and the score + * @throws java.io.IOException If there is a low-level I/O error. + */ + ClassificationResult<T> assignClass(Document document) throws IOException; + + /** + * Get all the classes (sorted by score, descending) assigned to the given {@link org.apache.lucene.document.Document}. + * + * @param document a {@link org.apache.lucene.document.Document} to be classified. Fields are considered features for the classification. + * @return the whole list of {@link org.apache.lucene.classification.ClassificationResult}, the classes and scores. Returns null if the classifier can't make lists. + * @throws java.io.IOException If there is a low-level I/O error. + */ + List<ClassificationResult<T>> getClasses(Document document) throws IOException; + + /** + * Get the first max classes (sorted by score, descending) assigned to the given {@link org.apache.lucene.document.Document}. + * + * @param document a {@link org.apache.lucene.document.Document} to be classified. Fields are considered features for the classification. + * @param max the number of return list elements + * @return the whole list of {@link org.apache.lucene.classification.ClassificationResult}, the classes and scores. Cut for "max" number of elements. Returns null if the classifier can't make lists. + * @throws java.io.IOException If there is a low-level I/O error. + */ + List<ClassificationResult<T>> getClasses(Document document, int max) throws IOException; + +} diff --git a/lucene/classification/src/java/org/apache/lucene/classification/document/KNearestNeighborDocumentClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/document/KNearestNeighborDocumentClassifier.java new file mode 100644 index 00000000000..d211b34c1f4 --- /dev/null +++ b/lucene/classification/src/java/org/apache/lucene/classification/document/KNearestNeighborDocumentClassifier.java @@ -0,0 +1,146 @@ +package org.apache.lucene.classification.document; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +import java.io.IOException; +import java.io.StringReader; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.apache.lucene.analysis.Analyzer; +import org.apache.lucene.classification.ClassificationResult; +import org.apache.lucene.classification.KNearestNeighborClassifier; +import org.apache.lucene.document.Document; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.WildcardQuery; +import org.apache.lucene.search.similarities.Similarity; +import org.apache.lucene.util.BytesRef; + +/** + * A k-Nearest Neighbor Document classifier (see http://en.wikipedia.org/wiki/K-nearest_neighbors) based + * on {@link org.apache.lucene.queries.mlt.MoreLikeThis}. + * + * @lucene.experimental + */ +public class KNearestNeighborDocumentClassifier extends KNearestNeighborClassifier implements DocumentClassifier<BytesRef> { + protected Map<String, Analyzer> field2analyzer; + + /** + * Creates a {@link KNearestNeighborDocumentClassifier}. + * + * @param leafReader the reader on the index to be used for classification + * @param similarity the {@link Similarity} to be used by the underlying {@link IndexSearcher} or {@code null} + * (defaults to {@link org.apache.lucene.search.similarities.ClassicSimilarity}) + * @param query a {@link org.apache.lucene.search.Query} to eventually filter the docs used for training the classifier, or {@code null} + * if all the indexed docs should be used + * @param k the no. of docs to select in the MLT results to find the nearest neighbor + * @param minDocsFreq {@link org.apache.lucene.queries.mlt.MoreLikeThis#minDocFreq} parameter + * @param minTermFreq {@link org.apache.lucene.queries.mlt.MoreLikeThis#minTermFreq} parameter + * @param classFieldName the name of the field used as the output for the classifier + * @param field2analyzer a map that associates each field name with the related {@link org.apache.lucene.analysis.Analyzer} + * @param textFieldNames the names of the fields used as the inputs for the classifier; they can contain a boosting indication, e.g. title^10 + */ + public KNearestNeighborDocumentClassifier(LeafReader leafReader, Similarity similarity, Query query, int k, int minDocsFreq, + int minTermFreq, String classFieldName, Map<String, Analyzer> field2analyzer, String...
textFieldNames) { + super(leafReader, similarity, null, query, k, minDocsFreq, minTermFreq, classFieldName, textFieldNames); + this.field2analyzer = field2analyzer; + } + + /** + * {@inheritDoc} + */ + @Override + public ClassificationResult<BytesRef> assignClass(Document document) throws IOException { + TopDocs knnResults = knnSearch(document); + List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults); + ClassificationResult<BytesRef> assignedClass = null; + double maxscore = -Double.MAX_VALUE; + for (ClassificationResult<BytesRef> cl : assignedClasses) { + if (cl.getScore() > maxscore) { + assignedClass = cl; + maxscore = cl.getScore(); + } + } + return assignedClass; + } + + /** + * {@inheritDoc} + */ + @Override + public List<ClassificationResult<BytesRef>> getClasses(Document document) throws IOException { + TopDocs knnResults = knnSearch(document); + List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults); + Collections.sort(assignedClasses); + return assignedClasses; + } + + /** + * {@inheritDoc} + */ + @Override + public List<ClassificationResult<BytesRef>> getClasses(Document document, int max) throws IOException { + TopDocs knnResults = knnSearch(document); + List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults); + Collections.sort(assignedClasses); + return assignedClasses.subList(0, max); + } + + /** + * Returns the top k results from a More Like This query based on the input document + * + * @param document the document to use for More Like This search + * @return the top results for the MLT query + * @throws IOException If there is a low-level I/O error + */ + private TopDocs knnSearch(Document document) throws IOException { + BooleanQuery.Builder mltQuery = new BooleanQuery.Builder(); + + for (String fieldName : textFieldNames) { + String boost = null; + if (fieldName.contains("^")) { + String[] field2boost = fieldName.split("\\^"); + fieldName = field2boost[0]; + boost = field2boost[1]; + } + String[] fieldValues = document.getValues(fieldName); + if (boost != null) { + mlt.setBoost(true); + mlt.setBoostFactor(Float.parseFloat(boost)); + } + mlt.setAnalyzer(field2analyzer.get(fieldName)); + for (String fieldContent : fieldValues) { + mltQuery.add(new BooleanClause(mlt.like(fieldName, new StringReader(fieldContent)), BooleanClause.Occur.SHOULD)); + } + mlt.setBoost(false); + } + Query classFieldQuery = new WildcardQuery(new Term(classFieldName, "*")); + mltQuery.add(new BooleanClause(classFieldQuery, BooleanClause.Occur.MUST)); + if (query != null) { + mltQuery.add(query, BooleanClause.Occur.MUST); + } + return indexSearcher.search(mltQuery.build(), k); + } +}
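Both document classifiers in this patch accept the "field^boost" convention in textFieldNames (split on '^', default boost 1), as knnSearch above and analyzeSeedDocument below do. A minimal standalone sketch of just that parsing step, with illustrative field names:

// Sketch of the "field^boost" parsing used by the document classifiers.
public class FieldBoostDemo {
  public static void main(String[] args) {
    String[] textFieldNames = {"title^100", "body", "author^2"}; // illustrative specs
    for (String spec : textFieldNames) {
      String fieldName = spec;
      float boost = 1f;                          // default: no boost
      if (spec.contains("^")) {
        String[] field2boost = spec.split("\\^"); // '^' is a regex metachar, hence the escape
        fieldName = field2boost[0];
        boost = Float.parseFloat(field2boost[1]);
      }
      System.out.println(fieldName + " -> boost " + boost);
    }
  }
}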
diff --git a/lucene/classification/src/java/org/apache/lucene/classification/document/SimpleNaiveBayesDocumentClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/document/SimpleNaiveBayesDocumentClassifier.java new file mode 100644 index 00000000000..6f1f0da0d50 --- /dev/null +++ b/lucene/classification/src/java/org/apache/lucene/classification/document/SimpleNaiveBayesDocumentClassifier.java @@ -0,0 +1,289 @@ +package org.apache.lucene.classification.document; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; + +import org.apache.lucene.analysis.Analyzer; +import org.apache.lucene.analysis.TokenStream; +import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; +import org.apache.lucene.classification.ClassificationResult; +import org.apache.lucene.classification.SimpleNaiveBayesClassifier; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.MultiFields; +import org.apache.lucene.index.Term; +import org.apache.lucene.index.Terms; +import org.apache.lucene.index.TermsEnum; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.TotalHitCountCollector; +import org.apache.lucene.search.WildcardQuery; +import org.apache.lucene.util.BytesRef; + +/** + * A simplistic Lucene based NaiveBayes classifier, see {@code http://en.wikipedia.org/wiki/Naive_Bayes_classifier} + * + * @lucene.experimental + */ +public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifier implements DocumentClassifier<BytesRef> { + /** + * {@link org.apache.lucene.analysis.Analyzer} to be used for tokenizing document fields + */ + protected Map<String, Analyzer> field2analyzer; + + /** + * Creates a new NaiveBayes classifier. + * + * @param leafReader the reader on the index to be used for classification + * @param query a {@link org.apache.lucene.search.Query} to eventually filter the docs used for training the classifier, or {@code null} + * if all the indexed docs should be used + * @param classFieldName the name of the field used as the output for the classifier; NOTE: it must not be heavily analyzed, + * as the returned class will be a token indexed for this field + * @param field2analyzer a map that associates each field name with the {@link org.apache.lucene.analysis.Analyzer} to use for it + * @param textFieldNames the names of the fields used as the inputs for the classifier; they can contain a boosting indication, e.g. title^10 + */ + public SimpleNaiveBayesDocumentClassifier(LeafReader leafReader, Query query, String classFieldName, Map<String, Analyzer> field2analyzer, String...
textFieldNames) { + super(leafReader, null, query, classFieldName, textFieldNames); + this.field2analyzer = field2analyzer; + } + + /** + * {@inheritDoc} + */ + @Override + public ClassificationResult<BytesRef> assignClass(Document document) throws IOException { + List<ClassificationResult<BytesRef>> assignedClasses = assignNormClasses(document); + ClassificationResult<BytesRef> assignedClass = null; + double maxscore = -Double.MAX_VALUE; + for (ClassificationResult<BytesRef> c : assignedClasses) { + if (c.getScore() > maxscore) { + assignedClass = c; + maxscore = c.getScore(); + } + } + return assignedClass; + } + + /** + * {@inheritDoc} + */ + @Override + public List<ClassificationResult<BytesRef>> getClasses(Document document) throws IOException { + List<ClassificationResult<BytesRef>> assignedClasses = assignNormClasses(document); + Collections.sort(assignedClasses); + return assignedClasses; + } + + /** + * {@inheritDoc} + */ + @Override + public List<ClassificationResult<BytesRef>> getClasses(Document document, int max) throws IOException { + List<ClassificationResult<BytesRef>> assignedClasses = assignNormClasses(document); + Collections.sort(assignedClasses); + return assignedClasses.subList(0, max); + } + + private List<ClassificationResult<BytesRef>> assignNormClasses(Document inputDocument) throws IOException { + List<ClassificationResult<BytesRef>> assignedClasses = new ArrayList<>(); + Map<String, List<String[]>> fieldName2tokensArray = new LinkedHashMap<>(); + Map<String, Float> fieldName2boost = new LinkedHashMap<>(); + Terms classes = MultiFields.getTerms(leafReader, classFieldName); + TermsEnum classesEnum = classes.iterator(); + BytesRef c; + + analyzeSeedDocument(inputDocument, fieldName2tokensArray, fieldName2boost); + + int docsWithClassSize = countDocsWithClass(); + while ((c = classesEnum.next()) != null) { + double classScore = 0; + for (String fieldName : textFieldNames) { + List<String[]> tokensArrays = fieldName2tokensArray.get(fieldName); + double fieldScore = 0; + for (String[] fieldTokensArray : tokensArrays) { + fieldScore += calculateLogPrior(c, docsWithClassSize) + calculateLogLikelihood(fieldTokensArray, fieldName, c, docsWithClassSize) * fieldName2boost.get(fieldName); + } + classScore += fieldScore; + } + assignedClasses.add(new ClassificationResult<>(BytesRef.deepCopyOf(c), classScore)); + } + ArrayList<ClassificationResult<BytesRef>> assignedClassesNorm = normClassificationResults(assignedClasses); + return assignedClassesNorm; + } + + /** + * This method performs the analysis of the seed document and extracts the per-field boosts if present. + * This is done only once for the seed document.
+ * + * @param inputDocument the seed unseen document + * @param fieldName2tokensArray a map that associates each field name with the list of token arrays for all its values + * @param fieldName2boost a map that associates each field name with its boost + * @throws IOException If there is a low-level I/O error + */ + private void analyzeSeedDocument(Document inputDocument, Map<String, List<String[]>> fieldName2tokensArray, Map<String, Float> fieldName2boost) throws IOException { + for (int i = 0; i < textFieldNames.length; i++) { + String fieldName = textFieldNames[i]; + float boost = 1; + List<String[]> tokenizedValues = new LinkedList<>(); + if (fieldName.contains("^")) { + String[] field2boost = fieldName.split("\\^"); + fieldName = field2boost[0]; + boost = Float.parseFloat(field2boost[1]); + } + Field[] fieldValues = inputDocument.getFields(fieldName); + for (Field fieldValue : fieldValues) { + TokenStream fieldTokens = fieldValue.tokenStream(field2analyzer.get(fieldName), null); + String[] fieldTokensArray = getTokenArray(fieldTokens); + tokenizedValues.add(fieldTokensArray); + } + fieldName2tokensArray.put(fieldName, tokenizedValues); + fieldName2boost.put(fieldName, boost); + textFieldNames[i] = fieldName; + } + } + + /** + * Counts the number of documents in the index having at least a value for the 'class' field + * + * @return the no. of documents having a value for the 'class' field + * @throws java.io.IOException If accessing term vectors or search fails + */ + protected int countDocsWithClass() throws IOException { + int docCount = MultiFields.getTerms(this.leafReader, this.classFieldName).getDocCount(); + if (docCount == -1) { // in case codec doesn't support getDocCount + TotalHitCountCollector classQueryCountCollector = new TotalHitCountCollector(); + BooleanQuery.Builder q = new BooleanQuery.Builder(); + q.add(new BooleanClause(new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))), BooleanClause.Occur.MUST)); + if (query != null) { + q.add(query, BooleanClause.Occur.MUST); + } + indexSearcher.search(q.build(), + classQueryCountCollector); + docCount = classQueryCountCollector.getTotalHits(); + } + return docCount; + } + + /** + * Returns a token array from the {@link org.apache.lucene.analysis.TokenStream} in input + * + * @param tokenizedText the tokenized content of a field + * @return a {@code String} array of the resulting tokens + * @throws java.io.IOException If tokenization fails because there is a low-level I/O error + */ + protected String[] getTokenArray(TokenStream tokenizedText) throws IOException { + Collection<String> tokens = new LinkedList<>(); + CharTermAttribute charTermAttribute = tokenizedText.addAttribute(CharTermAttribute.class); + tokenizedText.reset(); + while (tokenizedText.incrementToken()) { + tokens.add(charTermAttribute.toString()); + } + tokenizedText.end(); + tokenizedText.close(); + return tokens.toArray(new String[tokens.size()]); + } + + /** + * @param tokenizedText the tokenized content of a field + * @param fieldName the input field name + * @param c the class to calculate the score of + * @param docsWithClass the total number of docs that have a class + * @return a normalized score for the class + * @throws IOException If there is a low-level I/O error + */ + private double calculateLogLikelihood(String[] tokenizedText, String fieldName, BytesRef c, int docsWithClass) throws IOException { + // for each word + double result = 0d; + for (String word : tokenizedText) { + // search with text:word AND class:c + int hits = getWordFreqForClass(word, fieldName, c); + + //
num : count the no of times the word appears in documents of class c (+1) + double num = hits + 1; // +1 is added because of add 1 smoothing + + // den : for the whole dictionary, count the no of times a word appears in documents of class c (+|V|) + double den = getTextTermFreqForClass(c, fieldName) + docsWithClass; + + // P(w|c) = num/den + double wordProbability = num / den; + result += Math.log(wordProbability); + } + + // log(P(d|c)) = log(P(w1|c))+...+log(P(wn|c)) + double normScore = result / (tokenizedText.length); // this is normalized because otherwise long text fields would always be more important than short fields + return normScore; + } + + /** + * Returns the average number of unique terms times the number of docs belonging to the input class + * + * @param c the class + * @param fieldName the field whose terms are counted + * @return the average number of unique terms + * @throws java.io.IOException If there is a low-level I/O error + */ + private double getTextTermFreqForClass(BytesRef c, String fieldName) throws IOException { + double avgNumberOfUniqueTerms; + Terms terms = MultiFields.getTerms(leafReader, fieldName); + long numPostings = terms.getSumDocFreq(); // number of term/doc pairs + avgNumberOfUniqueTerms = numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc + int docsWithC = leafReader.docFreq(new Term(classFieldName, c)); + return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c + } + + /** + * Returns the number of documents of the input class (from the whole index or from a subset) + * that contain the word (in a specific field, or in all the fields if none is selected) + * + * @param word the token produced by the analyzer + * @param fieldName the field the word is coming from + * @param c the class + * @return the number of documents of the input class + * @throws java.io.IOException If there is a low-level I/O error + */ + private int getWordFreqForClass(String word, String fieldName, BytesRef c) throws IOException { + BooleanQuery.Builder booleanQuery = new BooleanQuery.Builder(); + BooleanQuery.Builder subQuery = new BooleanQuery.Builder(); + subQuery.add(new BooleanClause(new TermQuery(new Term(fieldName, word)), BooleanClause.Occur.SHOULD)); + booleanQuery.add(new BooleanClause(subQuery.build(), BooleanClause.Occur.MUST)); + booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST)); + if (query != null) { + booleanQuery.add(query, BooleanClause.Occur.MUST); + } + TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector(); + indexSearcher.search(booleanQuery.build(), totalHitCountCollector); + return totalHitCountCollector.getTotalHits(); + } + + private double calculateLogPrior(BytesRef currentClass, int docsWithClassSize) throws IOException { + return Math.log((double) docCount(currentClass)) - Math.log(docsWithClassSize); + } + + private int docCount(BytesRef countedClass) throws IOException { + return leafReader.docFreq(new Term(classFieldName, countedClass)); + } +}
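The division of each field's log likelihood by tokenizedText.length above is what lets the per-field boosts matter. A tiny worked sketch of why, with hypothetical numbers:

// Why the per-field log likelihood is divided by the token count: without the
// division, a 300-token body contributes ~100x the (negative) log mass of a
// 3-token title, drowning out any field boost.
public class LengthNormDemo {
  public static void main(String[] args) {
    double avgLogWordProb = -6.0;                    // hypothetical log P(w|c) per token
    int titleTokens = 3, bodyTokens = 300;
    double titleRaw = titleTokens * avgLogWordProb;  // -18
    double bodyRaw = bodyTokens * avgLogWordProb;    // -1800
    System.out.println("raw:        title=" + titleRaw + " body=" + bodyRaw);
    System.out.println("normalized: title=" + (titleRaw / titleTokens)
        + " body=" + (bodyRaw / bodyTokens));        // both -6.0, now comparable
  }
}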
diff --git a/lucene/classification/src/java/org/apache/lucene/classification/document/package-info.java b/lucene/classification/src/java/org/apache/lucene/classification/document/package-info.java new file mode 100644 index 00000000000..cf8b8a459ed --- /dev/null +++ b/lucene/classification/src/java/org/apache/lucene/classification/document/package-info.java @@ -0,0 +1,7 @@ +/** + * Uses already seen data (the indexed documents) to classify new documents. + * <br>
+ * Currently contains a (simplistic) Naive Bayes classifier and a k-Nearest + * Neighbor classifier. + */ +package org.apache.lucene.classification.document; diff --git a/lucene/classification/src/java/org/apache/lucene/classification/package-info.java b/lucene/classification/src/java/org/apache/lucene/classification/package-info.java index 9191d7fc948..1a6b4a26ca1 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/package-info.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/package-info.java @@ -16,7 +16,7 @@ */ /** - * Uses already seen data (the indexed documents) to classify new documents. + * Uses already seen data (the indexed documents) to classify an input (can be simple text or a structured document). * <br>
* Currently contains a (simplistic) Naive Bayes classifier, a k-Nearest * Neighbor classifier and a Perceptron based classifier. diff --git a/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java b/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java index a63868f0a59..605b4905e80 100644 --- a/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java +++ b/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java @@ -57,8 +57,8 @@ public abstract class ClassificationTestBase extends LuceneTestCase { protected static final BytesRef TECHNOLOGY_RESULT = new BytesRef("technology"); protected RandomIndexWriter indexWriter; - private Directory dir; - private FieldType ft; + protected Directory dir; + protected FieldType ft; protected String textFieldName; protected String categoryFieldName; diff --git a/lucene/classification/src/test/org/apache/lucene/classification/document/DocumentClassificationTestBase.java b/lucene/classification/src/test/org/apache/lucene/classification/document/DocumentClassificationTestBase.java new file mode 100644 index 00000000000..766ed24691b --- /dev/null +++ b/lucene/classification/src/test/org/apache/lucene/classification/document/DocumentClassificationTestBase.java @@ -0,0 +1,259 @@ +package org.apache.lucene.classification.document; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import java.io.IOException; +import java.util.LinkedHashMap; +import java.util.Map; + +import org.apache.lucene.analysis.Analyzer; +import org.apache.lucene.analysis.en.EnglishAnalyzer; +import org.apache.lucene.classification.ClassificationResult; +import org.apache.lucene.classification.ClassificationTestBase; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.RandomIndexWriter; +import org.apache.lucene.index.SlowCompositeReaderWrapper; +import org.apache.lucene.util.BytesRef; +import org.junit.Before; + +/** + * Base class for testing {@link org.apache.lucene.classification.Classifier}s + */ +public abstract class DocumentClassificationTestBase extends ClassificationTestBase { + + protected static final BytesRef VIDEOGAME_RESULT = new BytesRef("videogames"); + protected static final BytesRef VIDEOGAME_ANALYZED_RESULT = new BytesRef("videogam"); + protected static final BytesRef BATMAN_RESULT = new BytesRef("batman"); + + protected String titleFieldName = "title"; + protected String authorFieldName = "author"; + + protected Analyzer analyzer; + protected Map field2analyzer; + protected LeafReader leafReader; + + @Before + public void init() throws IOException { + analyzer = new EnglishAnalyzer(); + field2analyzer = new LinkedHashMap<>(); + field2analyzer.put(textFieldName, analyzer); + field2analyzer.put(titleFieldName, analyzer); + field2analyzer.put(authorFieldName, analyzer); + leafReader = populateDocumentClassificationIndex(analyzer); + } + + protected double checkCorrectDocumentClassification(DocumentClassifier classifier, Document inputDoc, T expectedResult) throws Exception { + ClassificationResult classificationResult = classifier.assignClass(inputDoc); + assertNotNull(classificationResult.getAssignedClass()); + assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass()); + double score = classificationResult.getScore(); + assertTrue("score should be between 0 and 1, got:" + score, score <= 1 && score >= 0); + return score; + } + + protected LeafReader populateDocumentClassificationIndex(Analyzer analyzer) throws IOException { + indexWriter.close(); + indexWriter = new RandomIndexWriter(random(), dir, newIndexWriterConfig(analyzer).setOpenMode(IndexWriterConfig.OpenMode.CREATE)); + indexWriter.commit(); + String text; + String title; + String author; + + Document doc = new Document(); + title = "Video games are an economic business"; + text = "Video games have become an art form and an industry. The video game industry is of increasing" + + " commercial importance, with growth driven particularly by the emerging Asian markets and mobile games." + + " As of 2015, video games generated sales of USD 74 billion annually worldwide, and were the third-largest" + + " segment in the U.S. 
entertainment market, behind broadcast and cable TV."; + author = "Ign"; + doc.add(new Field(textFieldName, text, ft)); + doc.add(new Field(titleFieldName, title, ft)); + doc.add(new Field(authorFieldName, author, ft)); + doc.add(new Field(categoryFieldName, "videogames", ft)); + doc.add(new Field(booleanFieldName, "false", ft)); + indexWriter.addDocument(doc); + + doc = new Document(); + title = "Video games: the definition of fun on PC and consoles"; + text = "A video game is an electronic game that involves human interaction with a user interface to generate" + + " visual feedback on a video device. The word video in video game traditionally referred to a raster display device," + + "[1] but it now implies any type of display device that can produce two- or three-dimensional images." + + " The electronic systems used to play video games are known as platforms; examples of these are personal" + + " computers and video game consoles. These platforms range from large mainframe computers to small handheld devices." + + " Specialized video games such as arcade games, while previously common, have gradually declined in use."; + author = "Ign"; + doc.add(new Field(textFieldName, text, ft)); + doc.add(new Field(titleFieldName, title, ft)); + doc.add(new Field(authorFieldName, author, ft)); + doc.add(new Field(categoryFieldName, "videogames", ft)); + doc.add(new Field(booleanFieldName, "false", ft)); + indexWriter.addDocument(doc); + + doc = new Document(); + title = "Video games: the history across PC, consoles and fun"; + text = "Early games used interactive electronic devices with various display formats. The earliest example is" + + " from 1947—a device was filed for a patent on 25 January 1947, by Thomas T. Goldsmith Jr. and Estle Ray Mann," + + " and issued on 14 December 1948, as U.S. Patent 2455992.[2]" + + "Inspired by radar display tech, it consisted of an analog device that allowed a user to control a vector-drawn" + + " dot on the screen to simulate a missile being fired at targets, which were drawings fixed to the screen.[3]"; + author = "Ign"; + doc.add(new Field(textFieldName, text, ft)); + doc.add(new Field(titleFieldName, title, ft)); + doc.add(new Field(authorFieldName, author, ft)); + doc.add(new Field(categoryFieldName, "videogames", ft)); + doc.add(new Field(booleanFieldName, "false", ft)); + indexWriter.addDocument(doc); + + doc = new Document(); + title = "Video games: the history"; + text = "Early games used interactive electronic devices with various display formats. The earliest example is" + + " from 1947—a device was filed for a patent on 25 January 1947, by Thomas T. Goldsmith Jr. and Estle Ray Mann," + + " and issued on 14 December 1948, as U.S. 
Patent 2455992.[2]" + + "Inspired by radar display tech, it consisted of an analog device that allowed a user to control a vector-drawn" + + " dot on the screen to simulate a missile being fired at targets, which were drawings fixed to the screen.[3]"; + author = "Ign"; + doc.add(new Field(textFieldName, text, ft)); + doc.add(new Field(titleFieldName, title, ft)); + doc.add(new Field(authorFieldName, author, ft)); + doc.add(new Field(categoryFieldName, "videogames", ft)); + doc.add(new Field(booleanFieldName, "false", ft)); + indexWriter.addDocument(doc); + + doc = new Document(); + title = "Batman: Arkham Knight PC Benchmarks, For What They're Worth"; + text = "Although I didn’t spend much time playing Batman: Arkham Origins, I remember the game rather well after" + + " testing it on no less than 30 graphics cards and 20 CPUs. Arkham Origins appeared to take full advantage of" + + " Unreal Engine 3, it ran smoothly on affordable GPUs, though it’s worth remembering that Origins was developed " + + "for last-gen consoles.This week marked the arrival of Batman: Arkham Knight, the fourth entry in WB’s Batman:" + + " Arkham series and a direct sequel to 2013’s Arkham Origins 2011’s Arkham City." + + "Arkham Knight is also powered by Unreal Engine 3, but you can expect noticeably improved graphics, in part because" + + " the PlayStation 4 and Xbox One have replaced the PS3 and 360 as the lowest common denominator."; + author = "Rocksteady Studios"; + doc.add(new Field(textFieldName, text, ft)); + doc.add(new Field(titleFieldName, title, ft)); + doc.add(new Field(authorFieldName, author, ft)); + doc.add(new Field(categoryFieldName, "batman", ft)); + doc.add(new Field(booleanFieldName, "false", ft)); + indexWriter.addDocument(doc); + + doc = new Document(); + title = "Face-Off: Batman: Arkham Knight, the Dark Knight returns!"; + text = "Despite the drama surrounding the PC release leading to its subsequent withdrawal, there's a sense of success" + + " in the console space as PlayStation 4 owners, and indeed those on Xbox One, get a superb rendition of Batman:" + + " Arkham Knight. It's fair to say Rocksteady sized up each console's strengths well ahead of producing its first" + + " current-gen title, and it's paid off in one of the best Batman games we've seen in years."; + author = "Rocksteady Studios"; + doc.add(new Field(textFieldName, text, ft)); + doc.add(new Field(titleFieldName, title, ft)); + doc.add(new Field(authorFieldName, author, ft)); + doc.add(new Field(categoryFieldName, "batman", ft)); + doc.add(new Field(booleanFieldName, "false", ft)); + indexWriter.addDocument(doc); + + doc = new Document(); + title = "Batman: Arkham Knight Having More Trouble, But This Time not in Gotham"; + text = "As news began to break about the numerous issues affecting the PC version of Batman: Arkham Knight, players" + + " of the console version breathed a sigh of relief and got back to playing the game. Now players of the PlayStation" + + " 4 version are having problems of their own, albeit much less severe ones." 
+ + "This time Batman will have a difficult time in Gotham."; + author = "Rocksteady Studios"; + doc.add(new Field(textFieldName, text, ft)); + doc.add(new Field(titleFieldName, title, ft)); + doc.add(new Field(authorFieldName, author, ft)); + doc.add(new Field(categoryFieldName, "batman", ft)); + doc.add(new Field(booleanFieldName, "false", ft)); + indexWriter.addDocument(doc); + + doc = new Document(); + title = "Batman: Arkham Knight the new legend of Gotham"; + text = "As news began to break about the numerous issues affecting the PC version of the game, players" + + " of the console version breathed a sigh of relief and got back to play. Now players of the PlayStation" + + " 4 version are having problems of their own, albeit much less severe ones."; + author = "Rocksteady Studios"; + doc.add(new Field(textFieldName, text, ft)); + doc.add(new Field(titleFieldName, title, ft)); + doc.add(new Field(authorFieldName, author, ft)); + doc.add(new Field(categoryFieldName, "batman", ft)); + doc.add(new Field(booleanFieldName, "false", ft)); + indexWriter.addDocument(doc); + + + doc = new Document(); + text = "unlabeled doc"; + doc.add(new Field(textFieldName, text, ft)); + indexWriter.addDocument(doc); + + indexWriter.commit(); + return SlowCompositeReaderWrapper.wrap(indexWriter.getReader()); + } + + protected Document getVideoGameDocument() { + Document doc = new Document(); + String title = "The new generation of PC and Console Video games"; + String text = "Recently a lot of games have been released for the latest generations of consoles and personal computers." + + "One of them is Batman: Arkham Knight released recently on PS4, X-box and personal computer." + + "Another important video game that will be released in November is Assassin's Creed, a classic series that sees its new installement on Halloween." + + "Recently a lot of problems affected the Assassin's creed series but this time it should ran smoothly on affordable GPUs." + + "Players are waiting for the versions of their favourite video games and so do we."; + String author = "Ign"; + doc.add(new Field(textFieldName, text, ft)); + doc.add(new Field(titleFieldName, title, ft)); + doc.add(new Field(authorFieldName, author, ft)); + doc.add(new Field(booleanFieldName, "false", ft)); + return doc; + } + + protected Document getBatmanDocument() { + Document doc = new Document(); + String title = "Batman: Arkham Knight new adventures for the super hero across Gotham, the Dark Knight has returned!"; + String title2 = "I am a second title !"; + String text = "This game is the electronic version of the famous super hero adventures.It involves the interaction with the open world" + + " of the city of Gotham. Finally the player will be able to have fun on its personal device." + + " The three-dimensional images of the game are stunning, because it uses the Unreal Engine 3." + + " The systems available are PS4, X-Box and personal computer." + + " Will the simulate missile that is going to be fired, success ?\" +\n" + + " Will this video game make the history" + + " Help you favourite super hero to defeat all his enemies. 
The Dark Knight has returned !"; + String author = "Rocksteady Studios"; + doc.add(new Field(textFieldName, text, ft)); + doc.add(new Field(titleFieldName, title, ft)); + doc.add(new Field(titleFieldName, title2, ft)); + doc.add(new Field(authorFieldName, author, ft)); + doc.add(new Field(booleanFieldName, "false", ft)); + return doc; + } + + protected Document getBatmanAmbiguosDocument() { + Document doc = new Document(); + String title = "Batman: Arkham Knight new adventures for the super hero across Gotham, the Dark Knight has returned! Batman will win !"; + String text = "Early games used interactive electronic devices with various display formats. The earliest example is" + + " from 1947—a device was filed for a patent on 25 January 1947, by Thomas T. Goldsmith Jr. and Estle Ray Mann," + + " and issued on 14 December 1948, as U.S. Patent 2455992.[2]" + + "Inspired by radar display tech, it consisted of an analog device that allowed a user to control a vector-drawn" + + " dot on the screen to simulate a missile being fired at targets, which were drawings fixed to the screen.[3]"; + String author = "Ign"; + doc.add(new Field(textFieldName, text, ft)); + doc.add(new Field(titleFieldName, title, ft)); + doc.add(new Field(authorFieldName, author, ft)); + doc.add(new Field(booleanFieldName, "false", ft)); + return doc; + } +} diff --git a/lucene/classification/src/test/org/apache/lucene/classification/document/KNearestNeighborDocumentClassifierTest.java b/lucene/classification/src/test/org/apache/lucene/classification/document/KNearestNeighborDocumentClassifierTest.java new file mode 100644 index 00000000000..10e4dce19e1 --- /dev/null +++ b/lucene/classification/src/test/org/apache/lucene/classification/document/KNearestNeighborDocumentClassifierTest.java @@ -0,0 +1,96 @@ +package org.apache.lucene.classification.document; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import org.apache.lucene.document.Document; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.util.BytesRef; +import org.junit.Test; + +/** + * Tests for {@link org.apache.lucene.classification.document.KNearestNeighborDocumentClassifier} + */ +public class KNearestNeighborDocumentClassifierTest extends DocumentClassificationTestBase<BytesRef> { + + @Test + public void testBasicDocumentClassification() throws Exception { + try { + Document videoGameDocument = getVideoGameDocument(); + Document batmanDocument = getBatmanDocument(); + checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader, null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), videoGameDocument, VIDEOGAME_RESULT); + checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader, null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), batmanDocument, BATMAN_RESULT); + // considering only the text we get the wrong classification because the text was ambiguous on purpose + checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader, null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), videoGameDocument, BATMAN_RESULT); + checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader, null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), batmanDocument, VIDEOGAME_RESULT); + + } finally { + if (leafReader != null) { + leafReader.close(); + } + } + } + + @Test + public void testBasicDocumentClassificationScore() throws Exception { + try { + Document videoGameDocument = getVideoGameDocument(); + Document batmanDocument = getBatmanDocument(); + double score1 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader, null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), videoGameDocument, VIDEOGAME_RESULT); + assertEquals(1.0, score1, 0); + double score2 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader, null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), batmanDocument, BATMAN_RESULT); + assertEquals(1.0, score2, 0); + // considering only the text we get the wrong classification because the text was ambiguous on purpose + double score3 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader, null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), videoGameDocument, BATMAN_RESULT); + assertEquals(1.0, score3, 0); + double score4 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader, null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), batmanDocument, VIDEOGAME_RESULT); + assertEquals(1.0, score4, 0); + } finally { + if (leafReader != null) { + leafReader.close(); + } + } + } + + @Test + public void testBoostedDocumentClassification() throws Exception { + try { + checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader, null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName + "^100", authorFieldName}), getBatmanAmbiguosDocument(), BATMAN_RESULT); + // without the boost, the wrong classification appears +
+      checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader, null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanAmbiguousDocument(), VIDEOGAME_RESULT);
+    } finally {
+      if (leafReader != null) {
+        leafReader.close();
+      }
+    }
+  }
+
+  @Test
+  public void testBasicDocumentClassificationWithQuery() throws Exception {
+    try {
+      TermQuery query = new TermQuery(new Term(authorFieldName, "ign"));
+      checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader, null, query, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_RESULT);
+      checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader, null, query, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), VIDEOGAME_RESULT);
+    } finally {
+      if (leafReader != null) {
+        leafReader.close();
+      }
+    }
+  }
+
+}
diff --git a/lucene/classification/src/test/org/apache/lucene/classification/document/SimpleNaiveBayesDocumentClassifierTest.java b/lucene/classification/src/test/org/apache/lucene/classification/document/SimpleNaiveBayesDocumentClassifierTest.java
new file mode 100644
index 00000000000..a1bcb54bf56
--- /dev/null
+++ b/lucene/classification/src/test/org/apache/lucene/classification/document/SimpleNaiveBayesDocumentClassifierTest.java
@@ -0,0 +1,76 @@
+package org.apache.lucene.classification.document;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.lucene.util.BytesRef;
+import org.junit.Test;
+
+/**
+ * Tests for {@link org.apache.lucene.classification.document.SimpleNaiveBayesDocumentClassifier}
+ */
+public class SimpleNaiveBayesDocumentClassifierTest extends DocumentClassificationTestBase<BytesRef> {
+
+  @Test
+  public void testBasicDocumentClassification() throws Exception {
+    try {
+      checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_ANALYZED_RESULT);
+      checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), BATMAN_RESULT);
+      // considering only the text field, the classification is wrong because the text is ambiguous on purpose
+      checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName, field2analyzer, new String[]{textFieldName}), getVideoGameDocument(), BATMAN_RESULT);
+      checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName, field2analyzer, new String[]{textFieldName}), getBatmanDocument(), VIDEOGAME_ANALYZED_RESULT);
+    } finally {
+      if (leafReader != null) {
+        leafReader.close();
+      }
+    }
+  }
+
+  @Test
+  public void testBasicDocumentClassificationScore() throws Exception {
+    try {
+      double score1 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_ANALYZED_RESULT);
+      assertEquals(0.88, score1, 0.01);
+      double score2 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), BATMAN_RESULT);
+      assertEquals(0.89, score2, 0.01);
+      // taking into consideration only the text field
+      double score3 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName, field2analyzer, new String[]{textFieldName}), getVideoGameDocument(), BATMAN_RESULT);
+      assertEquals(0.55, score3, 0.01);
+      double score4 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName, field2analyzer, new String[]{textFieldName}), getBatmanDocument(), VIDEOGAME_ANALYZED_RESULT);
+      assertEquals(0.52, score4, 0.01);
+    } finally {
+      if (leafReader != null) {
+        leafReader.close();
+      }
+    }
+  }
+
+  @Test
+  public void testBoostedDocumentClassification() throws Exception {
+    try {
+      checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName + "^100", authorFieldName}), getBatmanAmbiguousDocument(), BATMAN_RESULT);
+      // without the boost, the wrong classification appears
+      checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanAmbiguousDocument(), VIDEOGAME_ANALYZED_RESULT);
+    } finally {
+      if (leafReader != null) {
+        leafReader.close();
+      }
+    }
+  }
+
+}
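Editor's note: both test classes above lean on a checkCorrectDocumentClassification helper inherited from DocumentClassificationTestBase, whose body is not part of this patch. The sketch below shows what such a helper could plausibly look like, inferred from how the tests call it (classifier, input Document, expected label) and assert on the returned score. The method body, assertion messages, and the exact DocumentClassifier/ClassificationResult calls are assumptions for illustration, not the actual base-class code.

// Hypothetical sketch only -- the real helper lives in DocumentClassificationTestBase,
// which this patch does not show. T is the class-label type (BytesRef in these tests).
protected double checkCorrectDocumentClassification(DocumentClassifier<T> classifier,
                                                    Document inputDoc, T expectedResult) throws Exception {
  // ask the classifier for the single best class of the given document
  ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
  assertNotNull(classificationResult.getAssignedClass());
  assertEquals("wrong assigned class", expectedResult, classificationResult.getAssignedClass());
  // the classifiers normalize their outputs, so the score must fall in [0, 1];
  // returning it lets the score-oriented tests assert exact expected values
  double score = classificationResult.getScore();
  assertTrue("score should be between 0 and 1", score >= 0 && score <= 1);
  return score;
}

Returning the score rather than asserting a fixed value inside the helper is what lets the k-NN tests expect exactly 1.0 (a single neighbor always wins outright) while the naive Bayes tests expect softer normalized probabilities such as 0.88 or 0.55.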