diff --git a/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java index 9d3625d5385..e8069eea47a 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java @@ -58,6 +58,9 @@ public class KNearestNeighborClassifier implements Classifier { */ @Override public ClassificationResult assignClass(String text) throws IOException { + if (mlt == null) { + throw new IOException("You must first call Classifier#train first"); + } Query q = mlt.like(new StringReader(text), textFieldName); TopDocs topDocs = indexSearcher.search(q, k); return selectClassFromNeighbors(topDocs); diff --git a/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java index 06d5b83aeb0..6617b139b37 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java @@ -103,7 +103,7 @@ public class SimpleNaiveBayesClassifier implements Classifier { @Override public ClassificationResult assignClass(String inputDocument) throws IOException { if (atomicReader == null) { - throw new RuntimeException("need to train the classifier first"); + throw new IOException("You must first call Classifier#train first"); } double max = 0d; BytesRef foundClass = new BytesRef(); diff --git a/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java b/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java index 8d7e1f332bf..2225603f6d2 100644 --- a/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java +++ b/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java @@ -24,7 +24,6 @@ import org.apache.lucene.document.TextField; import org.apache.lucene.index.RandomIndexWriter; import org.apache.lucene.index.SlowCompositeReaderWrapper; import org.apache.lucene.store.Directory; -import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.LuceneTestCase; import org.junit.After; import org.junit.Before; @@ -32,12 +31,13 @@ import org.junit.Before; /** * Base class for testing {@link Classifier}s */ -public abstract class ClassificationTestBase extends LuceneTestCase { +public abstract class ClassificationTestBase extends LuceneTestCase { private RandomIndexWriter indexWriter; private String textFieldName; - private String classFieldName; private Directory dir; + String categoryFieldName; + String booleanFieldName; @Override @Before @@ -46,7 +46,8 @@ public abstract class ClassificationTestBase extends LuceneTestCase { dir = newDirectory(); indexWriter = new RandomIndexWriter(random(), dir); textFieldName = "text"; - classFieldName = "cat"; + categoryFieldName = "cat"; + booleanFieldName = "bool"; } @Override @@ -58,17 +59,17 @@ public abstract class ClassificationTestBase extends LuceneTestCase { } - protected void checkCorrectClassification(Classifier classifier, Analyzer analyzer) throws Exception { + protected void checkCorrectClassification(Classifier classifier, T expectedResult, Analyzer analyzer, String classFieldName) throws Exception { SlowCompositeReaderWrapper compositeReaderWrapper = null; try { populateIndex(analyzer); compositeReaderWrapper = new SlowCompositeReaderWrapper(indexWriter.getReader()); classifier.train(compositeReaderWrapper, textFieldName, classFieldName, analyzer); String newText = "Much is made of what the likes of Facebook, Google and Apple know about users. Truth is, Amazon may know more."; - ClassificationResult classificationResult = classifier.assignClass(newText); + ClassificationResult classificationResult = classifier.assignClass(newText); assertNotNull(classificationResult.getAssignedClass()); - assertEquals(new BytesRef("technology"), classificationResult.getAssignedClass()); - assertTrue(classificationResult.getScore() > 0); + assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass()); + assertTrue("got a not positive score " + classificationResult.getScore(), classificationResult.getScore() > 0); } finally { if (compositeReaderWrapper != null) compositeReaderWrapper.close(); @@ -86,48 +87,55 @@ public abstract class ClassificationTestBase extends LuceneTestCase { doc.add(new Field(textFieldName, "The traveling press secretary for Mitt Romney lost his cool and cursed at reporters " + "who attempted to ask questions of the Republican presidential candidate in a public plaza near the Tomb of " + "the Unknown Soldier in Warsaw Tuesday.", ft)); - doc.add(new Field(classFieldName, "politics", ft)); + doc.add(new Field(categoryFieldName, "politics", ft)); + doc.add(new Field(booleanFieldName, "false", ft)); indexWriter.addDocument(doc, analyzer); doc = new Document(); doc.add(new Field(textFieldName, "Mitt Romney seeks to assure Israel and Iran, as well as Jewish voters in the United" + " States, that he will be tougher against Iran's nuclear ambitions than President Barack Obama.", ft)); - doc.add(new Field(classFieldName, "politics", ft)); + doc.add(new Field(categoryFieldName, "politics", ft)); + doc.add(new Field(booleanFieldName, "false", ft)); indexWriter.addDocument(doc, analyzer); doc = new Document(); doc.add(new Field(textFieldName, "And there's a threshold question that he has to answer for the American people and " + "that's whether he is prepared to be commander-in-chief,\" she continued. \"As we look to the past events, we " + "know that this raises some questions about his preparedness and we'll see how the rest of his trip goes.\"", ft)); - doc.add(new Field(classFieldName, "politics", ft)); + doc.add(new Field(categoryFieldName, "politics", ft)); + doc.add(new Field(booleanFieldName, "false", ft)); indexWriter.addDocument(doc, analyzer); doc = new Document(); doc.add(new Field(textFieldName, "Still, when it comes to gun policy, many congressional Democrats have \"decided to " + "keep quiet and not go there,\" said Alan Lizotte, dean and professor at the State University of New York at " + "Albany's School of Criminal Justice.", ft)); - doc.add(new Field(classFieldName, "politics", ft)); + doc.add(new Field(categoryFieldName, "politics", ft)); + doc.add(new Field(booleanFieldName, "false", ft)); indexWriter.addDocument(doc, analyzer); doc = new Document(); doc.add(new Field(textFieldName, "Standing amongst the thousands of people at the state Capitol, Jorstad, director of " + "technology at the University of Wisconsin-La Crosse, documented the historic moment and shared it with the " + "world through the Internet.", ft)); - doc.add(new Field(classFieldName, "technology", ft)); + doc.add(new Field(categoryFieldName, "technology", ft)); + doc.add(new Field(booleanFieldName, "true", ft)); indexWriter.addDocument(doc, analyzer); doc = new Document(); doc.add(new Field(textFieldName, "So, about all those experts and analysts who've spent the past year or so saying " + "Facebook was going to make a phone. A new expert has stepped forward to say it's not going to happen.", ft)); - doc.add(new Field(classFieldName, "technology", ft)); + doc.add(new Field(categoryFieldName, "technology", ft)); + doc.add(new Field(booleanFieldName, "true", ft)); indexWriter.addDocument(doc, analyzer); doc = new Document(); doc.add(new Field(textFieldName, "More than 400 million people trust Google with their e-mail, and 50 million store files" + " in the cloud using the Dropbox service. People manage their bank accounts, pay bills, trade stocks and " + "generally transfer or store huge volumes of personal data online.", ft)); - doc.add(new Field(classFieldName, "technology", ft)); + doc.add(new Field(categoryFieldName, "technology", ft)); + doc.add(new Field(booleanFieldName, "true", ft)); indexWriter.addDocument(doc, analyzer); indexWriter.commit(); diff --git a/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java b/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java index 6bc5402dcf0..c80f45a90a5 100644 --- a/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java +++ b/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java @@ -17,16 +17,17 @@ package org.apache.lucene.classification; import org.apache.lucene.analysis.MockAnalyzer; +import org.apache.lucene.util.BytesRef; import org.junit.Test; /** * Testcase for {@link KNearestNeighborClassifier} */ -public class KNearestNeighborClassifierTest extends ClassificationTestBase { +public class KNearestNeighborClassifierTest extends ClassificationTestBase { @Test public void testBasicUsage() throws Exception { - checkCorrectClassification(new KNearestNeighborClassifier(1), new MockAnalyzer(random())); + checkCorrectClassification(new KNearestNeighborClassifier(1), new BytesRef("technology"), new MockAnalyzer(random()), categoryFieldName); } } diff --git a/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java b/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java index e203a2027fd..40341587b42 100644 --- a/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java +++ b/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java @@ -19,6 +19,7 @@ package org.apache.lucene.classification; import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.analysis.MockAnalyzer; import org.apache.lucene.analysis.ngram.EdgeNGramTokenizer; +import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.LuceneTestCase; import org.junit.Test; @@ -29,16 +30,16 @@ import java.io.Reader; */ // TODO : eventually remove this if / when fallback methods exist for all un-supportable codec methods (see LUCENE-4872) @LuceneTestCase.SuppressCodecs("Lucene3x") -public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase { +public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase { @Test public void testBasicUsage() throws Exception { - checkCorrectClassification(new SimpleNaiveBayesClassifier(), new MockAnalyzer(random())); + checkCorrectClassification(new SimpleNaiveBayesClassifier(), new BytesRef("technology"), new MockAnalyzer(random()), categoryFieldName); } @Test public void testNGramUsage() throws Exception { - checkCorrectClassification(new SimpleNaiveBayesClassifier(), new NGramAnalyzer()); + checkCorrectClassification(new SimpleNaiveBayesClassifier(), new BytesRef("technology"), new NGramAnalyzer(), categoryFieldName); } private class NGramAnalyzer extends Analyzer {