LUCENE-4959: Fix incorrect return value in SimpleNaiveBayesClassifier.assignClass.

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1476650 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Adrien Grand 2013-04-27 18:09:30 +00:00
parent 60151ce379
commit b242be2680
5 changed files with 16 additions and 7 deletions

View File

@ -65,6 +65,9 @@ Bug Fixes
toString() was changed to better reflect the reader structure.
(Mike McCandless, Uwe Schindler)
* LUCENE-4959: Fix incorrect return value in
SimpleNaiveBayesClassifier.assignClass. (Alexey Kutin via Adrien Grand)
Optimizations
* LUCENE-4938: Don't use an unnecessarily large priority queue in IndexSearcher

View File

@ -117,7 +117,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
double clVal = calculatePrior(next) * calculateLikelihood(tokenizedDoc, next);
if (clVal > max) {
max = clVal;
foundClass = next.clone();
foundClass = BytesRef.deepCopyOf(next);
}
}
return new ClassificationResult<BytesRef>(foundClass, max);

View File

@ -24,6 +24,7 @@ import org.apache.lucene.document.TextField;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.SlowCompositeReaderWrapper;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.LuceneTestCase;
import org.junit.After;
import org.junit.Before;
@ -32,6 +33,11 @@ import org.junit.Before;
* Base class for testing {@link Classifier}s
*/
public abstract class ClassificationTestBase<T> extends LuceneTestCase {
public final static String POLITICS_INPUT = "Here are some interesting questions and answers about Mitt Romney.. If you don't know the answer to the question about Mitt Romney, then simply click on the answer below the question section.";
public static final BytesRef POLITICS_RESULT = new BytesRef("politics");
public static final String TECHNOLOGY_INPUT = "Much is made of what the likes of Facebook, Google and Apple know about users. Truth is, Amazon may know more.";
public static final BytesRef TECHNOLOGY_RESULT = new BytesRef("technology");
private RandomIndexWriter indexWriter;
private String textFieldName;
@ -59,14 +65,13 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
}
protected void checkCorrectClassification(Classifier<T> classifier, T expectedResult, Analyzer analyzer, String classFieldName) throws Exception {
protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String classFieldName) throws Exception {
SlowCompositeReaderWrapper compositeReaderWrapper = null;
try {
populateIndex(analyzer);
compositeReaderWrapper = new SlowCompositeReaderWrapper(indexWriter.getReader());
classifier.train(compositeReaderWrapper, textFieldName, classFieldName, analyzer);
String newText = "Much is made of what the likes of Facebook, Google and Apple know about users. Truth is, Amazon may know more.";
ClassificationResult<T> classificationResult = classifier.assignClass(newText);
ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
assertNotNull(classificationResult.getAssignedClass());
assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
assertTrue("got a not positive score " + classificationResult.getScore(), classificationResult.getScore() > 0);

View File

@ -27,7 +27,7 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase<Bytes
@Test
public void testBasicUsage() throws Exception {
checkCorrectClassification(new KNearestNeighborClassifier(1), new BytesRef("technology"), new MockAnalyzer(random()), categoryFieldName);
checkCorrectClassification(new KNearestNeighborClassifier(1), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), categoryFieldName);
}
}

View File

@ -34,12 +34,13 @@ public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase<Bytes
@Test
public void testBasicUsage() throws Exception {
checkCorrectClassification(new SimpleNaiveBayesClassifier(), new BytesRef("technology"), new MockAnalyzer(random()), categoryFieldName);
checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), categoryFieldName);
checkCorrectClassification(new SimpleNaiveBayesClassifier(), POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), categoryFieldName);
}
@Test
public void testNGramUsage() throws Exception {
checkCorrectClassification(new SimpleNaiveBayesClassifier(), new BytesRef("technology"), new NGramAnalyzer(), categoryFieldName);
checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new NGramAnalyzer(), categoryFieldName);
}
private class NGramAnalyzer extends Analyzer {