mirror of https://github.com/apache/lucene.git
LUCENE-6731 - added Similarity parameter to kNN
git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1695034 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
5186ce4e23
commit
42c00ed22b
|
@ -36,6 +36,8 @@ import org.apache.lucene.search.Query;
|
|||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.search.WildcardQuery;
|
||||
import org.apache.lucene.search.similarities.DefaultSimilarity;
|
||||
import org.apache.lucene.search.similarities.Similarity;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
|
||||
/**
|
||||
|
@ -58,6 +60,8 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
|
|||
*
|
||||
* @param leafReader the reader on the index to be used for classification
|
||||
* @param analyzer an {@link Analyzer} used to analyze unseen text
|
||||
* @param similarity the {@link Similarity} to be used by the underlying {@link IndexSearcher} or {@code null}
|
||||
* (defaults to {@link org.apache.lucene.search.similarities.DefaultSimilarity})
|
||||
* @param query a {@link Query} to eventually filter the docs used for training the classifier, or {@code null}
|
||||
* if all the indexed docs should be used
|
||||
* @param k the no. of docs to select in the MLT results to find the nearest neighbor
|
||||
|
@ -66,7 +70,7 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
|
|||
* @param classFieldName the name of the field used as the output for the classifier
|
||||
* @param textFieldNames the name of the fields used as the inputs for the classifier, they can contain boosting indication e.g. title^10
|
||||
*/
|
||||
public KNearestNeighborClassifier(LeafReader leafReader, Analyzer analyzer, Query query, int k, int minDocsFreq,
|
||||
public KNearestNeighborClassifier(LeafReader leafReader, Similarity similarity, Analyzer analyzer, Query query, int k, int minDocsFreq,
|
||||
int minTermFreq, String classFieldName, String... textFieldNames) {
|
||||
this.textFieldNames = textFieldNames;
|
||||
this.classFieldName = classFieldName;
|
||||
|
@ -74,6 +78,11 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
|
|||
this.mlt.setAnalyzer(analyzer);
|
||||
this.mlt.setFieldNames(textFieldNames);
|
||||
this.indexSearcher = new IndexSearcher(leafReader);
|
||||
if (similarity != null) {
|
||||
this.indexSearcher.setSimilarity(similarity);
|
||||
} else {
|
||||
this.indexSearcher.setSimilarity(new DefaultSimilarity());
|
||||
}
|
||||
if (minDocsFreq > 0) {
|
||||
mlt.setMinDocFreq(minDocsFreq);
|
||||
}
|
||||
|
|
|
@ -87,12 +87,13 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
|
|||
dir.close();
|
||||
}
|
||||
|
||||
protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult) throws Exception {
|
||||
protected ClassificationResult<T> checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult) throws Exception {
|
||||
ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
|
||||
assertNotNull(classificationResult.getAssignedClass());
|
||||
assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
|
||||
double score = classificationResult.getScore();
|
||||
assertTrue("score should be between 0 and 1, got:" + score, score <= 1 && score >= 0);
|
||||
return classificationResult;
|
||||
}
|
||||
|
||||
protected void checkOnlineClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName) throws Exception {
|
||||
|
|
|
@ -24,6 +24,7 @@ import org.apache.lucene.analysis.en.EnglishAnalyzer;
|
|||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.search.TermQuery;
|
||||
import org.apache.lucene.search.similarities.LMDirichletSimilarity;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.junit.Test;
|
||||
|
||||
|
@ -38,8 +39,11 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase<Bytes
|
|||
try {
|
||||
MockAnalyzer analyzer = new MockAnalyzer(random());
|
||||
leafReader = populateSampleIndex(analyzer);
|
||||
checkCorrectClassification(new KNearestNeighborClassifier(leafReader, analyzer, null, 1, 0, 0, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
|
||||
checkCorrectClassification(new KNearestNeighborClassifier(leafReader, analyzer, null, 3, 2, 1, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
|
||||
checkCorrectClassification(new KNearestNeighborClassifier(leafReader, null, analyzer, null, 1, 0, 0, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
|
||||
checkCorrectClassification(new KNearestNeighborClassifier(leafReader, new LMDirichletSimilarity(), analyzer, null, 1, 0, 0, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
|
||||
ClassificationResult<BytesRef> resultDS = checkCorrectClassification(new KNearestNeighborClassifier(leafReader, null, analyzer, null, 3, 2, 1, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
|
||||
ClassificationResult<BytesRef> resultLMS = checkCorrectClassification(new KNearestNeighborClassifier(leafReader, new LMDirichletSimilarity(), analyzer, null, 3, 2, 1, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
|
||||
assertTrue(resultDS.getScore() != resultLMS.getScore());
|
||||
} finally {
|
||||
if (leafReader != null) {
|
||||
leafReader.close();
|
||||
|
@ -60,7 +64,7 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase<Bytes
|
|||
try {
|
||||
Analyzer analyzer = new EnglishAnalyzer();
|
||||
leafReader = populateSampleIndex(analyzer);
|
||||
KNearestNeighborClassifier knnClassifier = new KNearestNeighborClassifier(leafReader, analyzer, null, 6, 1, 1, categoryFieldName, textFieldName);
|
||||
KNearestNeighborClassifier knnClassifier = new KNearestNeighborClassifier(leafReader, null, analyzer, null, 6, 1, 1, categoryFieldName, textFieldName);
|
||||
List<ClassificationResult<BytesRef>> classes = knnClassifier.getClasses(STRONG_TECHNOLOGY_INPUT);
|
||||
assertTrue(classes.get(0).getScore() > classes.get(1).getScore());
|
||||
checkCorrectClassification(knnClassifier, STRONG_TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
|
||||
|
@ -85,7 +89,7 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase<Bytes
|
|||
try {
|
||||
Analyzer analyzer = new EnglishAnalyzer();
|
||||
leafReader = populateSampleIndex(analyzer);
|
||||
KNearestNeighborClassifier knnClassifier = new KNearestNeighborClassifier(leafReader, analyzer, null, 3, 1, 1, categoryFieldName, textFieldName);
|
||||
KNearestNeighborClassifier knnClassifier = new KNearestNeighborClassifier(leafReader, null,analyzer, null, 3, 1, 1, categoryFieldName, textFieldName);
|
||||
List<ClassificationResult<BytesRef>> classes = knnClassifier.getClasses(SUPER_STRONG_TECHNOLOGY_INPUT);
|
||||
assertTrue(classes.get(0).getScore() > classes.get(1).getScore());
|
||||
checkCorrectClassification(knnClassifier, SUPER_STRONG_TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
|
||||
|
@ -103,7 +107,7 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase<Bytes
|
|||
MockAnalyzer analyzer = new MockAnalyzer(random());
|
||||
leafReader = populateSampleIndex(analyzer);
|
||||
TermQuery query = new TermQuery(new Term(textFieldName, "it"));
|
||||
checkCorrectClassification(new KNearestNeighborClassifier(leafReader, analyzer, query, 1, 0, 0, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
|
||||
checkCorrectClassification(new KNearestNeighborClassifier(leafReader, null, analyzer, query, 1, 0, 0, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
|
||||
} finally {
|
||||
if (leafReader != null) {
|
||||
leafReader.close();
|
||||
|
|
|
@ -73,7 +73,7 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase {
|
|||
try {
|
||||
MockAnalyzer analyzer = new MockAnalyzer(random());
|
||||
reader = populateSampleIndex(analyzer);
|
||||
Classifier<BytesRef> classifier = new KNearestNeighborClassifier(reader, analyzer, null, 1, 0, 0, categoryFieldName, textFieldName);
|
||||
Classifier<BytesRef> classifier = new KNearestNeighborClassifier(reader, null, analyzer, null, 1, 0, 0, categoryFieldName, textFieldName);
|
||||
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, categoryFieldName, textFieldName);
|
||||
assertNotNull(confusionMatrix);
|
||||
assertNotNull(confusionMatrix.getLinearizedMatrix());
|
||||
|
|
Loading…
Reference in New Issue