diff --git a/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java index 77f04164cc3..fdcef140921 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java @@ -38,6 +38,7 @@ import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.WildcardQuery; +import org.apache.lucene.search.similarities.BM25Similarity; import org.apache.lucene.search.similarities.ClassicSimilarity; import org.apache.lucene.search.similarities.Similarity; import org.apache.lucene.util.BytesRef; @@ -106,7 +107,7 @@ public class KNearestNeighborClassifier implements Classifier { if (similarity != null) { this.indexSearcher.setSimilarity(similarity); } else { - this.indexSearcher.setSimilarity(new ClassicSimilarity()); + this.indexSearcher.setSimilarity(new BM25Similarity()); } if (minDocsFreq > 0) { mlt.setMinDocFreq(minDocsFreq); @@ -124,7 +125,10 @@ public class KNearestNeighborClassifier implements Classifier { */ @Override public ClassificationResult assignClass(String text) throws IOException { - TopDocs knnResults = knnSearch(text); + return classifyFromTopDocs(knnSearch(text)); + } + + protected ClassificationResult classifyFromTopDocs(TopDocs knnResults) throws IOException { List> assignedClasses = buildListFromTopDocs(knnResults); ClassificationResult assignedClass = null; double maxscore = -Double.MAX_VALUE; diff --git a/lucene/classification/src/java/org/apache/lucene/classification/document/KNearestNeighborDocumentClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/document/KNearestNeighborDocumentClassifier.java index e01090a9cac..adcb13b6a27 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/document/KNearestNeighborDocumentClassifier.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/document/KNearestNeighborDocumentClassifier.java @@ -77,17 +77,7 @@ public class KNearestNeighborDocumentClassifier extends KNearestNeighborClassifi */ @Override public ClassificationResult assignClass(Document document) throws IOException { - TopDocs knnResults = knnSearch(document); - List> assignedClasses = buildListFromTopDocs(knnResults); - ClassificationResult assignedClass = null; - double maxscore = -Double.MAX_VALUE; - for (ClassificationResult cl : assignedClasses) { - if (cl.getScore() > maxscore) { - assignedClass = cl; - maxscore = cl.getScore(); - } - } - return assignedClass; + return classifyFromTopDocs(knnSearch(document)); } /** diff --git a/lucene/classification/src/test/org/apache/lucene/classification/document/KNearestNeighborDocumentClassifierTest.java b/lucene/classification/src/test/org/apache/lucene/classification/document/KNearestNeighborDocumentClassifierTest.java index 8c885fb9c72..a323724e53d 100644 --- a/lucene/classification/src/test/org/apache/lucene/classification/document/KNearestNeighborDocumentClassifierTest.java +++ b/lucene/classification/src/test/org/apache/lucene/classification/document/KNearestNeighborDocumentClassifierTest.java @@ -33,8 +33,9 @@ public class KNearestNeighborDocumentClassifierTest extends DocumentClassificati try { Document videoGameDocument = getVideoGameDocument(); Document batmanDocument = getBatmanDocument(); - checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), videoGameDocument, VIDEOGAME_RESULT); - checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), batmanDocument, BATMAN_RESULT); + KNearestNeighborDocumentClassifier classifier = new KNearestNeighborDocumentClassifier(indexReader, null, null, 1, 4, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}); + checkCorrectDocumentClassification(classifier, videoGameDocument, VIDEOGAME_RESULT); + checkCorrectDocumentClassification(classifier, batmanDocument, BATMAN_RESULT); // considering only the text we have wrong classification because the text was ambiguos on purpose checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), videoGameDocument, BATMAN_RESULT); checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), batmanDocument, VIDEOGAME_RESULT); @@ -51,9 +52,10 @@ public class KNearestNeighborDocumentClassifierTest extends DocumentClassificati try { Document videoGameDocument = getVideoGameDocument(); Document batmanDocument = getBatmanDocument(); - double score1 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), videoGameDocument, VIDEOGAME_RESULT); + KNearestNeighborDocumentClassifier classifier = new KNearestNeighborDocumentClassifier(indexReader, null, null, 1, 4, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}); + double score1 = checkCorrectDocumentClassification(classifier, videoGameDocument, VIDEOGAME_RESULT); assertEquals(1.0,score1,0); - double score2 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), batmanDocument, BATMAN_RESULT); + double score2 = checkCorrectDocumentClassification(classifier, batmanDocument, BATMAN_RESULT); assertEquals(1.0,score2,0); // considering only the text we have wrong classification because the text was ambiguos on purpose double score3 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), videoGameDocument, BATMAN_RESULT);