mirror of https://github.com/apache/lucene.git
LUCENE-7823 - added bm25 nb classifier
This commit is contained in:
parent
fb56948e70
commit
8990500183
|
@ -0,0 +1,243 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.classification;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
|
||||
import org.apache.lucene.analysis.Analyzer;
|
||||
import org.apache.lucene.analysis.TokenStream;
|
||||
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.MultiFields;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.index.Terms;
|
||||
import org.apache.lucene.index.TermsEnum;
|
||||
import org.apache.lucene.search.BooleanClause;
|
||||
import org.apache.lucene.search.BooleanQuery;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.search.TermQuery;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.search.similarities.BM25Similarity;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
|
||||
/**
|
||||
* A classifier approximating naive bayes classifier by using pure queries on BM25.
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public class BM25NBClassifier implements Classifier<BytesRef> {
|
||||
|
||||
/**
|
||||
* {@link IndexReader} used to access the {@link Classifier}'s
|
||||
* index
|
||||
*/
|
||||
private final IndexReader indexReader;
|
||||
|
||||
/**
|
||||
* names of the fields to be used as input text
|
||||
*/
|
||||
private final String[] textFieldNames;
|
||||
|
||||
/**
|
||||
* name of the field to be used as a class / category output
|
||||
*/
|
||||
private final String classFieldName;
|
||||
|
||||
/**
|
||||
* {@link Analyzer} to be used for tokenizing unseen input text
|
||||
*/
|
||||
private final Analyzer analyzer;
|
||||
|
||||
/**
|
||||
* {@link IndexSearcher} to run searches on the index for retrieving frequencies
|
||||
*/
|
||||
private final IndexSearcher indexSearcher;
|
||||
|
||||
/**
|
||||
* {@link Query} used to eventually filter the document set to be used to classify
|
||||
*/
|
||||
private final Query query;
|
||||
|
||||
/**
|
||||
* Creates a new NaiveBayes classifier.
|
||||
*
|
||||
* @param indexReader the reader on the index to be used for classification
|
||||
* @param analyzer an {@link Analyzer} used to analyze unseen text
|
||||
* @param query a {@link Query} to eventually filter the docs used for training the classifier, or {@code null}
|
||||
* if all the indexed docs should be used
|
||||
* @param classFieldName the name of the field used as the output for the classifier NOTE: must not be havely analyzed
|
||||
* as the returned class will be a token indexed for this field
|
||||
* @param textFieldNames the name of the fields used as the inputs for the classifier, NO boosting supported per field
|
||||
*/
|
||||
public BM25NBClassifier(IndexReader indexReader, Analyzer analyzer, Query query, String classFieldName, String... textFieldNames) {
|
||||
this.indexReader = indexReader;
|
||||
this.indexSearcher = new IndexSearcher(this.indexReader);
|
||||
this.indexSearcher.setSimilarity(new BM25Similarity());
|
||||
this.textFieldNames = textFieldNames;
|
||||
this.classFieldName = classFieldName;
|
||||
this.analyzer = analyzer;
|
||||
this.query = query;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public ClassificationResult<BytesRef> assignClass(String inputDocument) throws IOException {
|
||||
return assignClassNormalizedList(inputDocument).get(0);
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public List<ClassificationResult<BytesRef>> getClasses(String text) throws IOException {
|
||||
List<ClassificationResult<BytesRef>> assignedClasses = assignClassNormalizedList(text);
|
||||
Collections.sort(assignedClasses);
|
||||
return assignedClasses;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public List<ClassificationResult<BytesRef>> getClasses(String text, int max) throws IOException {
|
||||
List<ClassificationResult<BytesRef>> assignedClasses = assignClassNormalizedList(text);
|
||||
Collections.sort(assignedClasses);
|
||||
return assignedClasses.subList(0, max);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate probabilities for all classes for a given input text
|
||||
*
|
||||
* @param inputDocument the input text as a {@code String}
|
||||
* @return a {@code List} of {@code ClassificationResult}, one for each existing class
|
||||
* @throws IOException if assigning probabilities fails
|
||||
*/
|
||||
private List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException {
|
||||
List<ClassificationResult<BytesRef>> assignedClasses = new ArrayList<>();
|
||||
|
||||
Terms classes = MultiFields.getTerms(indexReader, classFieldName);
|
||||
TermsEnum classesEnum = classes.iterator();
|
||||
BytesRef next;
|
||||
String[] tokenizedText = tokenize(inputDocument);
|
||||
while ((next = classesEnum.next()) != null) {
|
||||
if (next.length > 0) {
|
||||
Term term = new Term(this.classFieldName, next);
|
||||
assignedClasses.add(new ClassificationResult<>(term.bytes(), calculateLogPrior(term) + calculateLogLikelihood(tokenizedText, term)));
|
||||
}
|
||||
}
|
||||
|
||||
return normClassificationResults(assignedClasses);
|
||||
}
|
||||
|
||||
/**
|
||||
* Normalize the classification results based on the max score available
|
||||
*
|
||||
* @param assignedClasses the list of assigned classes
|
||||
* @return the normalized results
|
||||
*/
|
||||
private ArrayList<ClassificationResult<BytesRef>> normClassificationResults(List<ClassificationResult<BytesRef>> assignedClasses) {
|
||||
// normalization; the values transforms to a 0-1 range
|
||||
ArrayList<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
|
||||
if (!assignedClasses.isEmpty()) {
|
||||
Collections.sort(assignedClasses);
|
||||
// this is a negative number closest to 0 = a
|
||||
double smax = assignedClasses.get(0).getScore();
|
||||
|
||||
double sumLog = 0;
|
||||
// log(sum(exp(x_n-a)))
|
||||
for (ClassificationResult<BytesRef> cr : assignedClasses) {
|
||||
// getScore-smax <=0 (both negative, smax is the smallest abs()
|
||||
sumLog += Math.exp(cr.getScore() - smax);
|
||||
}
|
||||
// loga=a+log(sum(exp(x_n-a))) = log(sum(exp(x_n)))
|
||||
double loga = smax;
|
||||
loga += Math.log(sumLog);
|
||||
|
||||
// 1/sum*x = exp(log(x))*1/sum = exp(log(x)-log(sum))
|
||||
for (ClassificationResult<BytesRef> cr : assignedClasses) {
|
||||
double scoreDiff = cr.getScore() - loga;
|
||||
returnList.add(new ClassificationResult<>(cr.getAssignedClass(), Math.exp(scoreDiff)));
|
||||
}
|
||||
}
|
||||
return returnList;
|
||||
}
|
||||
|
||||
/**
|
||||
* tokenize a <code>String</code> on this classifier's text fields and analyzer
|
||||
*
|
||||
* @param text the <code>String</code> representing an input text (to be classified)
|
||||
* @return a <code>String</code> array of the resulting tokens
|
||||
* @throws IOException if tokenization fails
|
||||
*/
|
||||
private String[] tokenize(String text) throws IOException {
|
||||
Collection<String> result = new LinkedList<>();
|
||||
for (String textFieldName : textFieldNames) {
|
||||
try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, text)) {
|
||||
CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class);
|
||||
tokenStream.reset();
|
||||
while (tokenStream.incrementToken()) {
|
||||
result.add(charTermAttribute.toString());
|
||||
}
|
||||
tokenStream.end();
|
||||
}
|
||||
}
|
||||
return result.toArray(new String[result.size()]);
|
||||
}
|
||||
|
||||
private double calculateLogLikelihood(String[] tokens, Term term) throws IOException {
|
||||
double result = 0d;
|
||||
for (String word : tokens) {
|
||||
result += Math.log(getTermProbForClass(term, word));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private double getTermProbForClass(Term classTerm, String... words) throws IOException {
|
||||
BooleanQuery.Builder builder = new BooleanQuery.Builder();
|
||||
builder.add(new BooleanClause(new TermQuery(classTerm), BooleanClause.Occur.MUST));
|
||||
for (String textFieldName : textFieldNames) {
|
||||
for (String word : words) {
|
||||
builder.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.SHOULD));
|
||||
}
|
||||
}
|
||||
if (query != null) {
|
||||
builder.add(query, BooleanClause.Occur.MUST);
|
||||
}
|
||||
TopDocs search = indexSearcher.search(builder.build(), 1);
|
||||
return search.totalHits > 0 ? search.getMaxScore() : 1;
|
||||
}
|
||||
|
||||
private double calculateLogPrior(Term term) throws IOException {
|
||||
TermQuery termQuery = new TermQuery(term);
|
||||
BooleanQuery.Builder bq = new BooleanQuery.Builder();
|
||||
bq.add(termQuery, BooleanClause.Occur.MUST);
|
||||
if (query != null) {
|
||||
bq.add(query, BooleanClause.Occur.MUST);
|
||||
}
|
||||
TopDocs topDocs = indexSearcher.search(bq.build(), 1);
|
||||
return topDocs.totalHits > 0 ? Math.log(topDocs.getMaxScore()) : 0;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,154 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.classification;
|
||||
|
||||
import org.apache.lucene.analysis.Analyzer;
|
||||
import org.apache.lucene.analysis.MockAnalyzer;
|
||||
import org.apache.lucene.analysis.Tokenizer;
|
||||
import org.apache.lucene.analysis.core.KeywordTokenizer;
|
||||
import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
|
||||
import org.apache.lucene.analysis.reverse.ReverseStringFilter;
|
||||
import org.apache.lucene.classification.utils.ConfusionMatrixGenerator;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.MultiFields;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.index.Terms;
|
||||
import org.apache.lucene.index.TermsEnum;
|
||||
import org.apache.lucene.search.TermQuery;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.junit.Test;
|
||||
|
||||
/**
|
||||
* Tests for {@link BM25NBClassifier}
|
||||
*/
|
||||
public class BM25NBClassifierTest extends ClassificationTestBase<BytesRef> {
|
||||
|
||||
@Test
|
||||
public void testBasicUsage() throws Exception {
|
||||
LeafReader leafReader = null;
|
||||
try {
|
||||
MockAnalyzer analyzer = new MockAnalyzer(random());
|
||||
leafReader = getSampleIndex(analyzer);
|
||||
BM25NBClassifier classifier = new BM25NBClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName);
|
||||
checkCorrectClassification(classifier, TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
|
||||
} finally {
|
||||
if (leafReader != null) {
|
||||
leafReader.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBasicUsageWithQuery() throws Exception {
|
||||
LeafReader leafReader = null;
|
||||
try {
|
||||
MockAnalyzer analyzer = new MockAnalyzer(random());
|
||||
leafReader = getSampleIndex(analyzer);
|
||||
TermQuery query = new TermQuery(new Term(textFieldName, "not"));
|
||||
BM25NBClassifier classifier = new BM25NBClassifier(leafReader, analyzer, query, categoryFieldName, textFieldName);
|
||||
checkCorrectClassification(classifier, TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
|
||||
} finally {
|
||||
if (leafReader != null) {
|
||||
leafReader.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNGramUsage() throws Exception {
|
||||
LeafReader leafReader = null;
|
||||
try {
|
||||
Analyzer analyzer = new NGramAnalyzer();
|
||||
leafReader = getSampleIndex(analyzer);
|
||||
BM25NBClassifier classifier = new BM25NBClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName);
|
||||
checkCorrectClassification(classifier, TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
|
||||
} finally {
|
||||
if (leafReader != null) {
|
||||
leafReader.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private class NGramAnalyzer extends Analyzer {
|
||||
@Override
|
||||
protected TokenStreamComponents createComponents(String fieldName) {
|
||||
final Tokenizer tokenizer = new KeywordTokenizer();
|
||||
return new TokenStreamComponents(tokenizer, new ReverseStringFilter(new EdgeNGramTokenFilter(new ReverseStringFilter(tokenizer), 10, 20)));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPerformance() throws Exception {
|
||||
MockAnalyzer analyzer = new MockAnalyzer(random());
|
||||
LeafReader leafReader = getRandomIndex(analyzer, 100);
|
||||
try {
|
||||
long trainStart = System.currentTimeMillis();
|
||||
BM25NBClassifier classifier = new BM25NBClassifier(leafReader,
|
||||
analyzer, null, categoryFieldName, textFieldName);
|
||||
long trainEnd = System.currentTimeMillis();
|
||||
long trainTime = trainEnd - trainStart;
|
||||
assertTrue("training took more than 10s: " + trainTime / 1000 + "s", trainTime < 10000);
|
||||
|
||||
long evaluationStart = System.currentTimeMillis();
|
||||
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(leafReader,
|
||||
classifier, categoryFieldName, textFieldName, -1);
|
||||
assertNotNull(confusionMatrix);
|
||||
long evaluationEnd = System.currentTimeMillis();
|
||||
long evaluationTime = evaluationEnd - evaluationStart;
|
||||
assertTrue("evaluation took more than 2m: " + evaluationTime / 1000 + "s", evaluationTime < 120000);
|
||||
double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
|
||||
assertTrue("avg classification time: " + avgClassificationTime, 5000 > avgClassificationTime);
|
||||
|
||||
double f1 = confusionMatrix.getF1Measure();
|
||||
assertTrue(f1 >= 0d);
|
||||
assertTrue(f1 <= 1d);
|
||||
|
||||
double accuracy = confusionMatrix.getAccuracy();
|
||||
assertTrue(accuracy >= 0d);
|
||||
assertTrue(accuracy <= 1d);
|
||||
|
||||
double recall = confusionMatrix.getRecall();
|
||||
assertTrue(recall >= 0d);
|
||||
assertTrue(recall <= 1d);
|
||||
|
||||
double precision = confusionMatrix.getPrecision();
|
||||
assertTrue(precision >= 0d);
|
||||
assertTrue(precision <= 1d);
|
||||
|
||||
Terms terms = MultiFields.getTerms(leafReader, categoryFieldName);
|
||||
TermsEnum iterator = terms.iterator();
|
||||
BytesRef term;
|
||||
while ((term = iterator.next()) != null) {
|
||||
String s = term.utf8ToString();
|
||||
recall = confusionMatrix.getRecall(s);
|
||||
assertTrue(recall >= 0d);
|
||||
assertTrue(recall <= 1d);
|
||||
precision = confusionMatrix.getPrecision(s);
|
||||
assertTrue(precision >= 0d);
|
||||
assertTrue(precision <= 1d);
|
||||
double f1Measure = confusionMatrix.getF1Measure(s);
|
||||
assertTrue(f1Measure >= 0d);
|
||||
assertTrue(f1Measure <= 1d);
|
||||
}
|
||||
|
||||
} finally {
|
||||
leafReader.close();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue