LUCENE-7776 - use BM25 as the default similarity for the kNN classifier

This commit is contained in:
Tommaso Teofili 2017-04-11 10:44:36 +02:00
parent 15a1561d43
commit 0f60c4233c
3 changed files with 13 additions and 17 deletions

View File

@ -38,6 +38,7 @@ import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.WildcardQuery;
import org.apache.lucene.search.similarities.BM25Similarity;
import org.apache.lucene.search.similarities.ClassicSimilarity;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.util.BytesRef;
@ -106,7 +107,7 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
if (similarity != null) {
this.indexSearcher.setSimilarity(similarity);
} else {
this.indexSearcher.setSimilarity(new ClassicSimilarity());
this.indexSearcher.setSimilarity(new BM25Similarity());
}
if (minDocsFreq > 0) {
mlt.setMinDocFreq(minDocsFreq);
@ -124,7 +125,10 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
*/
@Override
public ClassificationResult<BytesRef> assignClass(String text) throws IOException {
TopDocs knnResults = knnSearch(text);
return classifyFromTopDocs(knnSearch(text));
}
protected ClassificationResult<BytesRef> classifyFromTopDocs(TopDocs knnResults) throws IOException {
List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults);
ClassificationResult<BytesRef> assignedClass = null;
double maxscore = -Double.MAX_VALUE;

View File

@ -77,17 +77,7 @@ public class KNearestNeighborDocumentClassifier extends KNearestNeighborClassifi
*/
@Override
public ClassificationResult<BytesRef> assignClass(Document document) throws IOException {
TopDocs knnResults = knnSearch(document);
List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults);
ClassificationResult<BytesRef> assignedClass = null;
double maxscore = -Double.MAX_VALUE;
for (ClassificationResult<BytesRef> cl : assignedClasses) {
if (cl.getScore() > maxscore) {
assignedClass = cl;
maxscore = cl.getScore();
}
}
return assignedClass;
return classifyFromTopDocs(knnSearch(document));
}
/**

View File

@ -33,8 +33,9 @@ public class KNearestNeighborDocumentClassifierTest extends DocumentClassificati
try {
Document videoGameDocument = getVideoGameDocument();
Document batmanDocument = getBatmanDocument();
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), videoGameDocument, VIDEOGAME_RESULT);
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), batmanDocument, BATMAN_RESULT);
KNearestNeighborDocumentClassifier classifier = new KNearestNeighborDocumentClassifier(indexReader, null, null, 1, 4, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName});
checkCorrectDocumentClassification(classifier, videoGameDocument, VIDEOGAME_RESULT);
checkCorrectDocumentClassification(classifier, batmanDocument, BATMAN_RESULT);
// considering only the text we get the wrong classification because the text was ambiguous on purpose
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), videoGameDocument, BATMAN_RESULT);
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), batmanDocument, VIDEOGAME_RESULT);
@ -51,9 +52,10 @@ public class KNearestNeighborDocumentClassifierTest extends DocumentClassificati
try {
Document videoGameDocument = getVideoGameDocument();
Document batmanDocument = getBatmanDocument();
double score1 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), videoGameDocument, VIDEOGAME_RESULT);
KNearestNeighborDocumentClassifier classifier = new KNearestNeighborDocumentClassifier(indexReader, null, null, 1, 4, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName});
double score1 = checkCorrectDocumentClassification(classifier, videoGameDocument, VIDEOGAME_RESULT);
assertEquals(1.0,score1,0);
double score2 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), batmanDocument, BATMAN_RESULT);
double score2 = checkCorrectDocumentClassification(classifier, batmanDocument, BATMAN_RESULT);
assertEquals(1.0,score2,0);
// considering only the text we get the wrong classification because the text was ambiguous on purpose
double score3 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), videoGameDocument, BATMAN_RESULT);