mirror of https://github.com/apache/lucene.git
LUCENE-4959: Fix incorrect return value in SimpleNaiveBayesClassifier.assignClass.
git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1476650 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
60151ce379
commit
b242be2680
|
@@ -65,6 +65,9 @@ Bug Fixes
|
|||
toString() was changed to better reflect the reader structure.
|
||||
(Mike McCandless, Uwe Schindler)
|
||||
|
||||
+ * LUCENE-4959: Fix incorrect return value in
|
||||
+ SimpleNaiveBayesClassifier.assignClass. (Alexey Kutin via Adrien Grand)
|
||||
|
||||
Optimizations
|
||||
|
||||
* LUCENE-4938: Don't use an unnecessarily large priority queue in IndexSearcher
|
||||
|
|
|
@@ -117,7 +117,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
|||
double clVal = calculatePrior(next) * calculateLikelihood(tokenizedDoc, next);
|
||||
if (clVal > max) {
|
||||
max = clVal;
|
||||
- foundClass = next.clone();
|
||||
+ foundClass = BytesRef.deepCopyOf(next);
|
||||
}
|
||||
}
|
||||
return new ClassificationResult<BytesRef>(foundClass, max);
|
||||
|
|
|
@@ -24,6 +24,7 @@ import org.apache.lucene.document.TextField;
|
|||
import org.apache.lucene.index.RandomIndexWriter;
|
||||
import org.apache.lucene.index.SlowCompositeReaderWrapper;
|
||||
import org.apache.lucene.store.Directory;
|
||||
+ import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.LuceneTestCase;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
|
@@ -32,6 +33,11 @@ import org.junit.Before;
|
|||
* Base class for testing {@link Classifier}s
|
||||
*/
|
||||
public abstract class ClassificationTestBase<T> extends LuceneTestCase {
|
||||
+ public final static String POLITICS_INPUT = "Here are some interesting questions and answers about Mitt Romney.. If you don't know the answer to the question about Mitt Romney, then simply click on the answer below the question section.";
|
||||
+ public static final BytesRef POLITICS_RESULT = new BytesRef("politics");
|
||||
|
||||
+ public static final String TECHNOLOGY_INPUT = "Much is made of what the likes of Facebook, Google and Apple know about users. Truth is, Amazon may know more.";
|
||||
+ public static final BytesRef TECHNOLOGY_RESULT = new BytesRef("technology");
|
||||
|
||||
private RandomIndexWriter indexWriter;
|
||||
private String textFieldName;
|
||||
|
@@ -59,14 +65,13 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
|
|||
}
|
||||
|
||||
|
||||
- protected void checkCorrectClassification(Classifier<T> classifier, T expectedResult, Analyzer analyzer, String classFieldName) throws Exception {
|
||||
+ protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String classFieldName) throws Exception {
|
||||
SlowCompositeReaderWrapper compositeReaderWrapper = null;
|
||||
try {
|
||||
populateIndex(analyzer);
|
||||
compositeReaderWrapper = new SlowCompositeReaderWrapper(indexWriter.getReader());
|
||||
classifier.train(compositeReaderWrapper, textFieldName, classFieldName, analyzer);
|
||||
- String newText = "Much is made of what the likes of Facebook, Google and Apple know about users. Truth is, Amazon may know more.";
|
||||
- ClassificationResult<T> classificationResult = classifier.assignClass(newText);
|
||||
+ ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
|
||||
assertNotNull(classificationResult.getAssignedClass());
|
||||
assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
|
||||
assertTrue("got a not positive score " + classificationResult.getScore(), classificationResult.getScore() > 0);
|
||||
|
|
|
@@ -27,7 +27,7 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase<Bytes
|
|||
|
||||
@Test
|
||||
public void testBasicUsage() throws Exception {
|
||||
- checkCorrectClassification(new KNearestNeighborClassifier(1), new BytesRef("technology"), new MockAnalyzer(random()), categoryFieldName);
|
||||
+ checkCorrectClassification(new KNearestNeighborClassifier(1), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), categoryFieldName);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@@ -34,12 +34,13 @@ public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase<Bytes
|
|||
|
||||
@Test
|
||||
public void testBasicUsage() throws Exception {
|
||||
- checkCorrectClassification(new SimpleNaiveBayesClassifier(), new BytesRef("technology"), new MockAnalyzer(random()), categoryFieldName);
|
||||
+ checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), categoryFieldName);
|
||||
+ checkCorrectClassification(new SimpleNaiveBayesClassifier(), POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), categoryFieldName);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNGramUsage() throws Exception {
|
||||
- checkCorrectClassification(new SimpleNaiveBayesClassifier(), new BytesRef("technology"), new NGramAnalyzer(), categoryFieldName);
|
||||
+ checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new NGramAnalyzer(), categoryFieldName);
|
||||
}
|
||||
|
||||
private class NGramAnalyzer extends Analyzer {
|
||||
|
|
Loading…
Reference in New Issue