diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index af675ba23d3..b916043820d 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -65,6 +65,9 @@ Bug Fixes toString() was changed to better reflect the reader structure. (Mike McCandless, Uwe Schindler) +* LUCENE-4959: Fix incorrect return value in + SimpleNaiveBayesClassifier.assignClass. (Alexey Kutin via Adrien Grand) + Optimizations * LUCENE-4938: Don't use an unnecessarily large priority queue in IndexSearcher diff --git a/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java index 6617b139b37..692fe4f8492 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java @@ -117,7 +117,7 @@ public class SimpleNaiveBayesClassifier implements Classifier { double clVal = calculatePrior(next) * calculateLikelihood(tokenizedDoc, next); if (clVal > max) { max = clVal; - foundClass = next.clone(); + foundClass = BytesRef.deepCopyOf(next); } } return new ClassificationResult(foundClass, max); diff --git a/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java b/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java index 2225603f6d2..3b08e43fb2f 100644 --- a/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java +++ b/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java @@ -24,6 +24,7 @@ import org.apache.lucene.document.TextField; import org.apache.lucene.index.RandomIndexWriter; import org.apache.lucene.index.SlowCompositeReaderWrapper; import org.apache.lucene.store.Directory; +import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.LuceneTestCase; import org.junit.After; import org.junit.Before; @@ -32,6 +33,11 @@ import org.junit.Before; * Base class for testing {@link Classifier}s */ public abstract class ClassificationTestBase extends LuceneTestCase { + public final static String POLITICS_INPUT = "Here are some interesting questions and answers about Mitt Romney.. If you don't know the answer to the question about Mitt Romney, then simply click on the answer below the question section."; + public static final BytesRef POLITICS_RESULT = new BytesRef("politics"); + + public static final String TECHNOLOGY_INPUT = "Much is made of what the likes of Facebook, Google and Apple know about users. Truth is, Amazon may know more."; + public static final BytesRef TECHNOLOGY_RESULT = new BytesRef("technology"); private RandomIndexWriter indexWriter; private String textFieldName; @@ -59,14 +65,13 @@ public abstract class ClassificationTestBase extends LuceneTestCase { } - protected void checkCorrectClassification(Classifier classifier, T expectedResult, Analyzer analyzer, String classFieldName) throws Exception { + protected void checkCorrectClassification(Classifier classifier, String inputDoc, T expectedResult, Analyzer analyzer, String classFieldName) throws Exception { SlowCompositeReaderWrapper compositeReaderWrapper = null; try { populateIndex(analyzer); compositeReaderWrapper = new SlowCompositeReaderWrapper(indexWriter.getReader()); classifier.train(compositeReaderWrapper, textFieldName, classFieldName, analyzer); - String newText = "Much is made of what the likes of Facebook, Google and Apple know about users. Truth is, Amazon may know more."; - ClassificationResult classificationResult = classifier.assignClass(newText); + ClassificationResult classificationResult = classifier.assignClass(inputDoc); assertNotNull(classificationResult.getAssignedClass()); assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass()); assertTrue("got a not positive score " + classificationResult.getScore(), classificationResult.getScore() > 0); diff --git a/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java b/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java index c80f45a90a5..2e2b066576c 100644 --- a/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java +++ b/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java @@ -27,7 +27,7 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase