mirror of https://github.com/apache/lucene.git
LUCENE-7776 - use bm25 for knn classifier
This commit is contained in:
parent
15a1561d43
commit
0f60c4233c
|
@ -38,6 +38,7 @@ import org.apache.lucene.search.Query;
|
|||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.search.WildcardQuery;
|
||||
import org.apache.lucene.search.similarities.BM25Similarity;
|
||||
import org.apache.lucene.search.similarities.ClassicSimilarity;
|
||||
import org.apache.lucene.search.similarities.Similarity;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
|
@ -106,7 +107,7 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
|
|||
if (similarity != null) {
|
||||
this.indexSearcher.setSimilarity(similarity);
|
||||
} else {
|
||||
this.indexSearcher.setSimilarity(new ClassicSimilarity());
|
||||
this.indexSearcher.setSimilarity(new BM25Similarity());
|
||||
}
|
||||
if (minDocsFreq > 0) {
|
||||
mlt.setMinDocFreq(minDocsFreq);
|
||||
|
@ -124,7 +125,10 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
|
|||
*/
|
||||
@Override
|
||||
public ClassificationResult<BytesRef> assignClass(String text) throws IOException {
|
||||
TopDocs knnResults = knnSearch(text);
|
||||
return classifyFromTopDocs(knnSearch(text));
|
||||
}
|
||||
|
||||
protected ClassificationResult<BytesRef> classifyFromTopDocs(TopDocs knnResults) throws IOException {
|
||||
List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults);
|
||||
ClassificationResult<BytesRef> assignedClass = null;
|
||||
double maxscore = -Double.MAX_VALUE;
|
||||
|
|
|
@ -77,17 +77,7 @@ public class KNearestNeighborDocumentClassifier extends KNearestNeighborClassifi
|
|||
*/
|
||||
@Override
|
||||
public ClassificationResult<BytesRef> assignClass(Document document) throws IOException {
|
||||
TopDocs knnResults = knnSearch(document);
|
||||
List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults);
|
||||
ClassificationResult<BytesRef> assignedClass = null;
|
||||
double maxscore = -Double.MAX_VALUE;
|
||||
for (ClassificationResult<BytesRef> cl : assignedClasses) {
|
||||
if (cl.getScore() > maxscore) {
|
||||
assignedClass = cl;
|
||||
maxscore = cl.getScore();
|
||||
}
|
||||
}
|
||||
return assignedClass;
|
||||
return classifyFromTopDocs(knnSearch(document));
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -33,8 +33,9 @@ public class KNearestNeighborDocumentClassifierTest extends DocumentClassificati
|
|||
try {
|
||||
Document videoGameDocument = getVideoGameDocument();
|
||||
Document batmanDocument = getBatmanDocument();
|
||||
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), videoGameDocument, VIDEOGAME_RESULT);
|
||||
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), batmanDocument, BATMAN_RESULT);
|
||||
KNearestNeighborDocumentClassifier classifier = new KNearestNeighborDocumentClassifier(indexReader, null, null, 1, 4, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName});
|
||||
checkCorrectDocumentClassification(classifier, videoGameDocument, VIDEOGAME_RESULT);
|
||||
checkCorrectDocumentClassification(classifier, batmanDocument, BATMAN_RESULT);
|
||||
// considering only the text we have wrong classification because the text was ambiguos on purpose
|
||||
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), videoGameDocument, BATMAN_RESULT);
|
||||
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), batmanDocument, VIDEOGAME_RESULT);
|
||||
|
@ -51,9 +52,10 @@ public class KNearestNeighborDocumentClassifierTest extends DocumentClassificati
|
|||
try {
|
||||
Document videoGameDocument = getVideoGameDocument();
|
||||
Document batmanDocument = getBatmanDocument();
|
||||
double score1 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), videoGameDocument, VIDEOGAME_RESULT);
|
||||
KNearestNeighborDocumentClassifier classifier = new KNearestNeighborDocumentClassifier(indexReader, null, null, 1, 4, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName});
|
||||
double score1 = checkCorrectDocumentClassification(classifier, videoGameDocument, VIDEOGAME_RESULT);
|
||||
assertEquals(1.0,score1,0);
|
||||
double score2 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), batmanDocument, BATMAN_RESULT);
|
||||
double score2 = checkCorrectDocumentClassification(classifier, batmanDocument, BATMAN_RESULT);
|
||||
assertEquals(1.0,score2,0);
|
||||
// considering only the text we have wrong classification because the text was ambiguos on purpose
|
||||
double score3 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), videoGameDocument, BATMAN_RESULT);
|
||||
|
|
Loading…
Reference in New Issue