LUCENE-6731 - added Similarity parameter to kNN

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1695034 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Tommaso Teofili 2015-08-10 11:50:10 +00:00
parent 5186ce4e23
commit 42c00ed22b
4 changed files with 22 additions and 8 deletions

View File

@ -36,6 +36,8 @@ import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.WildcardQuery;
import org.apache.lucene.search.similarities.DefaultSimilarity;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.util.BytesRef;
/**
@ -58,6 +60,8 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
*
* @param leafReader the reader on the index to be used for classification
* @param analyzer an {@link Analyzer} used to analyze unseen text
* @param similarity the {@link Similarity} to be used by the underlying {@link IndexSearcher} or {@code null}
* (defaults to {@link org.apache.lucene.search.similarities.DefaultSimilarity})
* @param query a {@link Query} to eventually filter the docs used for training the classifier, or {@code null}
* if all the indexed docs should be used
* @param k the no. of docs to select in the MLT results to find the nearest neighbor
@ -66,7 +70,7 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
* @param classFieldName the name of the field used as the output for the classifier
* @param textFieldNames the name of the fields used as the inputs for the classifier, they can contain boosting indication e.g. title^10
*/
public KNearestNeighborClassifier(LeafReader leafReader, Analyzer analyzer, Query query, int k, int minDocsFreq,
public KNearestNeighborClassifier(LeafReader leafReader, Similarity similarity, Analyzer analyzer, Query query, int k, int minDocsFreq,
int minTermFreq, String classFieldName, String... textFieldNames) {
this.textFieldNames = textFieldNames;
this.classFieldName = classFieldName;
@ -74,6 +78,11 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
this.mlt.setAnalyzer(analyzer);
this.mlt.setFieldNames(textFieldNames);
this.indexSearcher = new IndexSearcher(leafReader);
if (similarity != null) {
this.indexSearcher.setSimilarity(similarity);
} else {
this.indexSearcher.setSimilarity(new DefaultSimilarity());
}
if (minDocsFreq > 0) {
mlt.setMinDocFreq(minDocsFreq);
}

View File

@ -87,12 +87,13 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
dir.close();
}
protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult) throws Exception {
protected ClassificationResult<T> checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult) throws Exception {
ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
assertNotNull(classificationResult.getAssignedClass());
assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
double score = classificationResult.getScore();
assertTrue("score should be between 0 and 1, got:" + score, score <= 1 && score >= 0);
return classificationResult;
}
protected void checkOnlineClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName) throws Exception {

View File

@ -24,6 +24,7 @@ import org.apache.lucene.analysis.en.EnglishAnalyzer;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.similarities.LMDirichletSimilarity;
import org.apache.lucene.util.BytesRef;
import org.junit.Test;
@ -38,8 +39,11 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase<Bytes
try {
MockAnalyzer analyzer = new MockAnalyzer(random());
leafReader = populateSampleIndex(analyzer);
checkCorrectClassification(new KNearestNeighborClassifier(leafReader, analyzer, null, 1, 0, 0, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
checkCorrectClassification(new KNearestNeighborClassifier(leafReader, analyzer, null, 3, 2, 1, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
checkCorrectClassification(new KNearestNeighborClassifier(leafReader, null, analyzer, null, 1, 0, 0, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
checkCorrectClassification(new KNearestNeighborClassifier(leafReader, new LMDirichletSimilarity(), analyzer, null, 1, 0, 0, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
ClassificationResult<BytesRef> resultDS = checkCorrectClassification(new KNearestNeighborClassifier(leafReader, null, analyzer, null, 3, 2, 1, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
ClassificationResult<BytesRef> resultLMS = checkCorrectClassification(new KNearestNeighborClassifier(leafReader, new LMDirichletSimilarity(), analyzer, null, 3, 2, 1, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
assertTrue(resultDS.getScore() != resultLMS.getScore());
} finally {
if (leafReader != null) {
leafReader.close();
@ -60,7 +64,7 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase<Bytes
try {
Analyzer analyzer = new EnglishAnalyzer();
leafReader = populateSampleIndex(analyzer);
KNearestNeighborClassifier knnClassifier = new KNearestNeighborClassifier(leafReader, analyzer, null, 6, 1, 1, categoryFieldName, textFieldName);
KNearestNeighborClassifier knnClassifier = new KNearestNeighborClassifier(leafReader, null, analyzer, null, 6, 1, 1, categoryFieldName, textFieldName);
List<ClassificationResult<BytesRef>> classes = knnClassifier.getClasses(STRONG_TECHNOLOGY_INPUT);
assertTrue(classes.get(0).getScore() > classes.get(1).getScore());
checkCorrectClassification(knnClassifier, STRONG_TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
@ -85,7 +89,7 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase<Bytes
try {
Analyzer analyzer = new EnglishAnalyzer();
leafReader = populateSampleIndex(analyzer);
KNearestNeighborClassifier knnClassifier = new KNearestNeighborClassifier(leafReader, analyzer, null, 3, 1, 1, categoryFieldName, textFieldName);
KNearestNeighborClassifier knnClassifier = new KNearestNeighborClassifier(leafReader, null,analyzer, null, 3, 1, 1, categoryFieldName, textFieldName);
List<ClassificationResult<BytesRef>> classes = knnClassifier.getClasses(SUPER_STRONG_TECHNOLOGY_INPUT);
assertTrue(classes.get(0).getScore() > classes.get(1).getScore());
checkCorrectClassification(knnClassifier, SUPER_STRONG_TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
@ -103,7 +107,7 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase<Bytes
MockAnalyzer analyzer = new MockAnalyzer(random());
leafReader = populateSampleIndex(analyzer);
TermQuery query = new TermQuery(new Term(textFieldName, "it"));
checkCorrectClassification(new KNearestNeighborClassifier(leafReader, analyzer, query, 1, 0, 0, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
checkCorrectClassification(new KNearestNeighborClassifier(leafReader, null, analyzer, query, 1, 0, 0, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
} finally {
if (leafReader != null) {
leafReader.close();

View File

@ -73,7 +73,7 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase {
try {
MockAnalyzer analyzer = new MockAnalyzer(random());
reader = populateSampleIndex(analyzer);
Classifier<BytesRef> classifier = new KNearestNeighborClassifier(reader, analyzer, null, 1, 0, 0, categoryFieldName, textFieldName);
Classifier<BytesRef> classifier = new KNearestNeighborClassifier(reader, null, analyzer, null, 1, 0, 0, categoryFieldName, textFieldName);
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, categoryFieldName, textFieldName);
assertNotNull(confusionMatrix);
assertNotNull(confusionMatrix.getLinearizedMatrix());