LUCENE-7838 - added knn classifier based on flt

This commit is contained in:
Tommaso Teofili 2017-05-18 14:35:53 +02:00
parent afd70a48cc
commit bd9e32d358
6 changed files with 417 additions and 68 deletions

View File

@ -16,8 +16,9 @@
<orderEntry type="module" scope="TEST" module-name="lucene-test-framework" />
<orderEntry type="module" module-name="lucene-core" />
<orderEntry type="module" module-name="queries" />
<orderEntry type="module" scope="TEST" module-name="analysis-common" />
<orderEntry type="module" module-name="analysis-common" />
<orderEntry type="module" module-name="grouping" />
<orderEntry type="module" module-name="misc" />
<orderEntry type="module" module-name="sandbox" />
</component>
</module>

View File

@ -28,6 +28,8 @@
<path refid="base.classpath"/>
<pathelement path="${queries.jar}"/>
<pathelement path="${grouping.jar}"/>
<pathelement path="${sandbox.jar}"/>
<pathelement path="${analyzers-common.jar}"/>
</path>
<path id="test.classpath">
@ -36,16 +38,18 @@
<path refid="test.base.classpath"/>
</path>
<target name="compile-core" depends="jar-grouping,jar-queries,jar-analyzers-common,common.compile-core" />
<target name="compile-core" depends="jar-sandbox,jar-grouping,jar-queries,jar-analyzers-common,common.compile-core" />
<target name="jar-core" depends="common.jar-core" />
<target name="javadocs" depends="javadocs-grouping,compile-core,check-javadocs-uptodate"
<target name="javadocs" depends="javadocs-sandbox,javadocs-grouping,compile-core,check-javadocs-uptodate"
unless="javadocs-uptodate-${name}">
<invoke-module-javadoc>
<links>
<link href="../queries"/>
<link href="../analyzers/common"/>
<link href="../grouping"/>
<link href="../sandbox"/>
</links>
</invoke-module-javadoc>
</target>

View File

@ -0,0 +1,225 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.classification;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.sandbox.queries.FuzzyLikeThisQuery;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.WildcardQuery;
import org.apache.lucene.search.similarities.BM25Similarity;
import org.apache.lucene.search.similarities.ClassicSimilarity;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.util.BytesRef;
/**
* A k-Nearest Neighbor classifier based on {@link FuzzyLikeThisQuery}.
*
* @lucene.experimental
*/
public class KNearestFuzzyClassifier implements Classifier<BytesRef> {
/**
* the name of the fields used as the input text
*/
protected final String[] textFieldNames;
/**
* the name of the field used as the output text
*/
protected final String classFieldName;
/**
* an {@link IndexSearcher} used to perform queries
*/
protected final IndexSearcher indexSearcher;
/**
* the no. of docs to compare in order to find the nearest neighbor to the input text
*/
protected final int k;
/**
* a {@link Query} used to filter the documents that should be used from this classifier's underlying {@link LeafReader}
*/
protected final Query query;
private final Analyzer analyzer;
/**
* Creates a {@link KNearestFuzzyClassifier}.
*
* @param indexReader the reader on the index to be used for classification
* @param analyzer an {@link Analyzer} used to analyze unseen text
* @param similarity the {@link Similarity} to be used by the underlying {@link IndexSearcher} or {@code null}
* (defaults to {@link BM25Similarity})
* @param query a {@link Query} to eventually filter the docs used for training the classifier, or {@code null}
* if all the indexed docs should be used
* @param k the no. of docs to select in the MLT results to find the nearest neighbor
* @param classFieldName the name of the field used as the output for the classifier
* @param textFieldNames the name of the fields used as the inputs for the classifier, they can contain boosting indication e.g. title^10
*/
public KNearestFuzzyClassifier(IndexReader indexReader, Similarity similarity, Analyzer analyzer, Query query, int k,
String classFieldName, String... textFieldNames) {
this.textFieldNames = textFieldNames;
this.classFieldName = classFieldName;
this.analyzer = analyzer;
this.indexSearcher = new IndexSearcher(indexReader);
if (similarity != null) {
this.indexSearcher.setSimilarity(similarity);
} else {
this.indexSearcher.setSimilarity(new BM25Similarity());
}
this.query = query;
this.k = k;
}
/**
* {@inheritDoc}
*/
@Override
public ClassificationResult<BytesRef> assignClass(String text) throws IOException {
TopDocs knnResults = knnSearch(text);
List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults);
ClassificationResult<BytesRef> assignedClass = null;
double maxscore = -Double.MAX_VALUE;
for (ClassificationResult<BytesRef> cl : assignedClasses) {
if (cl.getScore() > maxscore) {
assignedClass = cl;
maxscore = cl.getScore();
}
}
return assignedClass;
}
/**
* {@inheritDoc}
*/
@Override
public List<ClassificationResult<BytesRef>> getClasses(String text) throws IOException {
TopDocs knnResults = knnSearch(text);
List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults);
Collections.sort(assignedClasses);
return assignedClasses;
}
/**
* {@inheritDoc}
*/
@Override
public List<ClassificationResult<BytesRef>> getClasses(String text, int max) throws IOException {
TopDocs knnResults = knnSearch(text);
List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults);
Collections.sort(assignedClasses);
return assignedClasses.subList(0, max);
}
private TopDocs knnSearch(String text) throws IOException {
BooleanQuery.Builder bq = new BooleanQuery.Builder();
FuzzyLikeThisQuery fuzzyLikeThisQuery = new FuzzyLikeThisQuery(300, analyzer);
for (String fieldName : textFieldNames) {
fuzzyLikeThisQuery.addTerms(text, fieldName, 1f, 2); // TODO: make this parameters configurable
}
bq.add(fuzzyLikeThisQuery, BooleanClause.Occur.MUST);
Query classFieldQuery = new WildcardQuery(new Term(classFieldName, "*"));
bq.add(new BooleanClause(classFieldQuery, BooleanClause.Occur.MUST));
if (query != null) {
bq.add(query, BooleanClause.Occur.MUST);
}
return indexSearcher.search(bq.build(), k);
}
/**
* build a list of classification results from search results
*
* @param topDocs the search results as a {@link TopDocs} object
* @return a {@link List} of {@link ClassificationResult}, one for each existing class
* @throws IOException if it's not possible to get the stored value of class field
*/
protected List<ClassificationResult<BytesRef>> buildListFromTopDocs(TopDocs topDocs) throws IOException {
Map<BytesRef, Integer> classCounts = new HashMap<>();
Map<BytesRef, Double> classBoosts = new HashMap<>(); // this is a boost based on class ranking positions in topDocs
float maxScore = topDocs.getMaxScore();
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
IndexableField storableField = indexSearcher.doc(scoreDoc.doc).getField(classFieldName);
if (storableField != null) {
BytesRef cl = new BytesRef(storableField.stringValue());
//update count
Integer count = classCounts.get(cl);
if (count != null) {
classCounts.put(cl, count + 1);
} else {
classCounts.put(cl, 1);
}
//update boost, the boost is based on the best score
Double totalBoost = classBoosts.get(cl);
double singleBoost = scoreDoc.score / maxScore;
if (totalBoost != null) {
classBoosts.put(cl, totalBoost + singleBoost);
} else {
classBoosts.put(cl, singleBoost);
}
}
}
List<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
List<ClassificationResult<BytesRef>> temporaryList = new ArrayList<>();
int sumdoc = 0;
for (Map.Entry<BytesRef, Integer> entry : classCounts.entrySet()) {
Integer count = entry.getValue();
Double normBoost = classBoosts.get(entry.getKey()) / count; //the boost is normalized to be 0<b<1
temporaryList.add(new ClassificationResult<>(entry.getKey().clone(), (count * normBoost) / (double) k));
sumdoc += count;
}
//correction
if (sumdoc < k) {
for (ClassificationResult<BytesRef> cr : temporaryList) {
returnList.add(new ClassificationResult<>(cr.getAssignedClass(), cr.getScore() * k / (double) sumdoc));
}
} else {
returnList = temporaryList;
}
return returnList;
}
@Override
public String toString() {
return "KNearestFuzzyClassifier{" +
"textFieldNames=" + Arrays.toString(textFieldNames) +
", classFieldName='" + classFieldName + '\'' +
", k=" + k +
", query=" + query +
", similarity=" + indexSearcher.getSimilarity(true) +
'}';
}
}

View File

@ -121,7 +121,7 @@ public class DatasetSplitter {
int b = 0;
// iterate over existing documents
for (GroupDocs group : topGroups.groups) {
for (GroupDocs<Object> group : topGroups.groups) {
int totalHits = group.totalHits;
double testSize = totalHits * testRatio;
int tc = 0;

View File

@ -0,0 +1,124 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.classification;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.MockAnalyzer;
import org.apache.lucene.analysis.Tokenizer;
import org.apache.lucene.analysis.core.KeywordTokenizer;
import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
import org.apache.lucene.analysis.reverse.ReverseStringFilter;
import org.apache.lucene.classification.utils.ConfusionMatrixGenerator;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.MultiFields;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.util.BytesRef;
import org.junit.Test;
/**
* Testcase for {@link KNearestFuzzyClassifier}
*/
public class KNearestFuzzyClassifierTest extends ClassificationTestBase<BytesRef> {
@Test
public void testBasicUsage() throws Exception {
LeafReader leafReader = null;
try {
MockAnalyzer analyzer = new MockAnalyzer(random());
leafReader = getSampleIndex(analyzer);
Classifier<BytesRef> classifier = new KNearestFuzzyClassifier(leafReader, null, analyzer, null, 3, categoryFieldName, textFieldName);
checkCorrectClassification(classifier, TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
checkCorrectClassification(classifier, POLITICS_INPUT, POLITICS_RESULT);
} finally {
if (leafReader != null) {
leafReader.close();
}
}
}
@Test
public void testBasicUsageWithQuery() throws Exception {
LeafReader leafReader = null;
try {
MockAnalyzer analyzer = new MockAnalyzer(random());
leafReader = getSampleIndex(analyzer);
TermQuery query = new TermQuery(new Term(textFieldName, "not"));
Classifier<BytesRef> classifier = new KNearestFuzzyClassifier(leafReader, null, analyzer, query, 3, categoryFieldName, textFieldName);
checkCorrectClassification(classifier, TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
} finally {
if (leafReader != null) {
leafReader.close();
}
}
}
@Test
public void testPerformance() throws Exception {
MockAnalyzer analyzer = new MockAnalyzer(random());
LeafReader leafReader = getRandomIndex(analyzer, 100);
try {
long trainStart = System.currentTimeMillis();
Classifier<BytesRef> classifier = new KNearestFuzzyClassifier(leafReader, null, analyzer, null, 3, categoryFieldName, textFieldName);
long trainEnd = System.currentTimeMillis();
long trainTime = trainEnd - trainStart;
assertTrue("training took more than 10s: " + trainTime / 1000 + "s", trainTime < 10000);
long evaluationStart = System.currentTimeMillis();
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(leafReader,
classifier, categoryFieldName, textFieldName, -1);
assertNotNull(confusionMatrix);
long evaluationEnd = System.currentTimeMillis();
long evaluationTime = evaluationEnd - evaluationStart;
assertTrue("evaluation took more than 2m: " + evaluationTime / 1000 + "s", evaluationTime < 120000);
double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
assertTrue(5000 > avgClassificationTime);
double accuracy = confusionMatrix.getAccuracy();
assertTrue(accuracy >= 0d);
assertTrue(accuracy <= 1d);
double recall = confusionMatrix.getRecall();
assertTrue(recall >= 0d);
assertTrue(recall <= 1d);
double precision = confusionMatrix.getPrecision();
assertTrue(precision >= 0d);
assertTrue(precision <= 1d);
Terms terms = MultiFields.getTerms(leafReader, categoryFieldName);
TermsEnum iterator = terms.iterator();
BytesRef term;
while ((term = iterator.next()) != null) {
String s = term.utf8ToString();
recall = confusionMatrix.getRecall(s);
assertTrue(recall >= 0d);
assertTrue(recall <= 1d);
precision = confusionMatrix.getPrecision(s);
assertTrue(precision >= 0d);
assertTrue(precision <= 1d);
double f1Measure = confusionMatrix.getF1Measure(s);
assertTrue(f1Measure >= 0d);
assertTrue(f1Measure <= 1d);
}
} finally {
leafReader.close();
}
}
}

View File

@ -21,11 +21,13 @@ import java.io.IOException;
import java.util.List;
import org.apache.lucene.analysis.MockAnalyzer;
import org.apache.lucene.classification.BM25NBClassifier;
import org.apache.lucene.classification.BooleanPerceptronClassifier;
import org.apache.lucene.classification.CachingNaiveBayesClassifier;
import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.ClassificationTestBase;
import org.apache.lucene.classification.Classifier;
import org.apache.lucene.classification.KNearestFuzzyClassifier;
import org.apache.lucene.classification.KNearestNeighborClassifier;
import org.apache.lucene.classification.SimpleNaiveBayesClassifier;
import org.apache.lucene.index.LeafReader;
@ -94,6 +96,15 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
Classifier<BytesRef> classifier = new SimpleNaiveBayesClassifier(reader, analyzer, null, categoryFieldName, textFieldName);
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
classifier, categoryFieldName, textFieldName, -1);
checkCM(confusionMatrix);
} finally {
if (reader != null) {
reader.close();
}
}
}
private void checkCM(ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix) {
assertNotNull(confusionMatrix);
assertNotNull(confusionMatrix.getLinearizedMatrix());
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
@ -110,6 +121,18 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
double f1Measure = confusionMatrix.getF1Measure();
assertTrue(f1Measure >= 0d);
assertTrue(f1Measure <= 1d);
}
@Test
public void testGetConfusionMatrixWithBM25NB() throws Exception {
LeafReader reader = null;
try {
MockAnalyzer analyzer = new MockAnalyzer(random());
reader = getSampleIndex(analyzer);
Classifier<BytesRef> classifier = new BM25NBClassifier(reader, analyzer, null, categoryFieldName, textFieldName);
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
classifier, categoryFieldName, textFieldName, -1);
checkCM(confusionMatrix);
} finally {
if (reader != null) {
reader.close();
@ -126,22 +149,7 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
Classifier<BytesRef> classifier = new CachingNaiveBayesClassifier(reader, analyzer, null, categoryFieldName, textFieldName);
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
classifier, categoryFieldName, textFieldName, -1);
assertNotNull(confusionMatrix);
assertNotNull(confusionMatrix.getLinearizedMatrix());
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
double accuracy = confusionMatrix.getAccuracy();
assertTrue(accuracy >= 0d);
assertTrue(accuracy <= 1d);
double precision = confusionMatrix.getPrecision();
assertTrue(precision >= 0d);
assertTrue(precision <= 1d);
double recall = confusionMatrix.getRecall();
assertTrue(recall >= 0d);
assertTrue(recall <= 1d);
double f1Measure = confusionMatrix.getF1Measure();
assertTrue(f1Measure >= 0d);
assertTrue(f1Measure <= 1d);
checkCM(confusionMatrix);
} finally {
if (reader != null) {
reader.close();
@ -158,22 +166,24 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
Classifier<BytesRef> classifier = new KNearestNeighborClassifier(reader, null, analyzer, null, 1, 0, 0, categoryFieldName, textFieldName);
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
classifier, categoryFieldName, textFieldName, -1);
assertNotNull(confusionMatrix);
assertNotNull(confusionMatrix.getLinearizedMatrix());
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
double accuracy = confusionMatrix.getAccuracy();
assertTrue(accuracy >= 0d);
assertTrue(accuracy <= 1d);
double precision = confusionMatrix.getPrecision();
assertTrue(precision >= 0d);
assertTrue(precision <= 1d);
double recall = confusionMatrix.getRecall();
assertTrue(recall >= 0d);
assertTrue(recall <= 1d);
double f1Measure = confusionMatrix.getF1Measure();
assertTrue(f1Measure >= 0d);
assertTrue(f1Measure <= 1d);
checkCM(confusionMatrix);
} finally {
if (reader != null) {
reader.close();
}
}
}
@Test
public void testGetConfusionMatrixWithFLTKNN() throws Exception {
LeafReader reader = null;
try {
MockAnalyzer analyzer = new MockAnalyzer(random());
reader = getSampleIndex(analyzer);
Classifier<BytesRef> classifier = new KNearestFuzzyClassifier(reader, null, analyzer, null, 1, categoryFieldName, textFieldName);
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
classifier, categoryFieldName, textFieldName, -1);
checkCM(confusionMatrix);
} finally {
if (reader != null) {
reader.close();
@ -190,22 +200,7 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
Classifier<Boolean> classifier = new BooleanPerceptronClassifier(reader, analyzer, null, 1, null, booleanFieldName, textFieldName);
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
classifier, booleanFieldName, textFieldName, -1);
assertNotNull(confusionMatrix);
assertNotNull(confusionMatrix.getLinearizedMatrix());
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
double accuracy = confusionMatrix.getAccuracy();
assertTrue(accuracy >= 0d);
assertTrue(accuracy <= 1d);
double precision = confusionMatrix.getPrecision();
assertTrue(precision >= 0d);
assertTrue(precision <= 1d);
double recall = confusionMatrix.getRecall();
assertTrue(recall >= 0d);
assertTrue(recall <= 1d);
double f1Measure = confusionMatrix.getF1Measure();
assertTrue(f1Measure >= 0d);
assertTrue(f1Measure <= 1d);
checkCM(confusionMatrix);
assertTrue(confusionMatrix.getPrecision("true") >= 0d);
assertTrue(confusionMatrix.getPrecision("true") <= 1d);
assertTrue(confusionMatrix.getPrecision("false") >= 0d);