mirror of https://github.com/apache/lucene.git
LUCENE-4917 - allowing ClassifierTestBase to be used not only for BytesRef classifiers
git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1465575 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
e0cd09fe00
commit
ceaf4996c0
|
@ -58,6 +58,9 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public ClassificationResult<BytesRef> assignClass(String text) throws IOException {
|
public ClassificationResult<BytesRef> assignClass(String text) throws IOException {
|
||||||
|
if (mlt == null) {
|
||||||
|
throw new IOException("You must first call Classifier#train first");
|
||||||
|
}
|
||||||
Query q = mlt.like(new StringReader(text), textFieldName);
|
Query q = mlt.like(new StringReader(text), textFieldName);
|
||||||
TopDocs topDocs = indexSearcher.search(q, k);
|
TopDocs topDocs = indexSearcher.search(q, k);
|
||||||
return selectClassFromNeighbors(topDocs);
|
return selectClassFromNeighbors(topDocs);
|
||||||
|
|
|
@ -103,7 +103,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
||||||
@Override
|
@Override
|
||||||
public ClassificationResult<BytesRef> assignClass(String inputDocument) throws IOException {
|
public ClassificationResult<BytesRef> assignClass(String inputDocument) throws IOException {
|
||||||
if (atomicReader == null) {
|
if (atomicReader == null) {
|
||||||
throw new RuntimeException("need to train the classifier first");
|
throw new IOException("You must first call Classifier#train first");
|
||||||
}
|
}
|
||||||
double max = 0d;
|
double max = 0d;
|
||||||
BytesRef foundClass = new BytesRef();
|
BytesRef foundClass = new BytesRef();
|
||||||
|
|
|
@ -24,7 +24,6 @@ import org.apache.lucene.document.TextField;
|
||||||
import org.apache.lucene.index.RandomIndexWriter;
|
import org.apache.lucene.index.RandomIndexWriter;
|
||||||
import org.apache.lucene.index.SlowCompositeReaderWrapper;
|
import org.apache.lucene.index.SlowCompositeReaderWrapper;
|
||||||
import org.apache.lucene.store.Directory;
|
import org.apache.lucene.store.Directory;
|
||||||
import org.apache.lucene.util.BytesRef;
|
|
||||||
import org.apache.lucene.util.LuceneTestCase;
|
import org.apache.lucene.util.LuceneTestCase;
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
|
@ -32,12 +31,13 @@ import org.junit.Before;
|
||||||
/**
|
/**
|
||||||
* Base class for testing {@link Classifier}s
|
* Base class for testing {@link Classifier}s
|
||||||
*/
|
*/
|
||||||
public abstract class ClassificationTestBase extends LuceneTestCase {
|
public abstract class ClassificationTestBase<T> extends LuceneTestCase {
|
||||||
|
|
||||||
private RandomIndexWriter indexWriter;
|
private RandomIndexWriter indexWriter;
|
||||||
private String textFieldName;
|
private String textFieldName;
|
||||||
private String classFieldName;
|
|
||||||
private Directory dir;
|
private Directory dir;
|
||||||
|
String categoryFieldName;
|
||||||
|
String booleanFieldName;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@Before
|
@Before
|
||||||
|
@ -46,7 +46,8 @@ public abstract class ClassificationTestBase extends LuceneTestCase {
|
||||||
dir = newDirectory();
|
dir = newDirectory();
|
||||||
indexWriter = new RandomIndexWriter(random(), dir);
|
indexWriter = new RandomIndexWriter(random(), dir);
|
||||||
textFieldName = "text";
|
textFieldName = "text";
|
||||||
classFieldName = "cat";
|
categoryFieldName = "cat";
|
||||||
|
booleanFieldName = "bool";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -58,17 +59,17 @@ public abstract class ClassificationTestBase extends LuceneTestCase {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
protected void checkCorrectClassification(Classifier<BytesRef> classifier, Analyzer analyzer) throws Exception {
|
protected void checkCorrectClassification(Classifier<T> classifier, T expectedResult, Analyzer analyzer, String classFieldName) throws Exception {
|
||||||
SlowCompositeReaderWrapper compositeReaderWrapper = null;
|
SlowCompositeReaderWrapper compositeReaderWrapper = null;
|
||||||
try {
|
try {
|
||||||
populateIndex(analyzer);
|
populateIndex(analyzer);
|
||||||
compositeReaderWrapper = new SlowCompositeReaderWrapper(indexWriter.getReader());
|
compositeReaderWrapper = new SlowCompositeReaderWrapper(indexWriter.getReader());
|
||||||
classifier.train(compositeReaderWrapper, textFieldName, classFieldName, analyzer);
|
classifier.train(compositeReaderWrapper, textFieldName, classFieldName, analyzer);
|
||||||
String newText = "Much is made of what the likes of Facebook, Google and Apple know about users. Truth is, Amazon may know more.";
|
String newText = "Much is made of what the likes of Facebook, Google and Apple know about users. Truth is, Amazon may know more.";
|
||||||
ClassificationResult<BytesRef> classificationResult = classifier.assignClass(newText);
|
ClassificationResult<T> classificationResult = classifier.assignClass(newText);
|
||||||
assertNotNull(classificationResult.getAssignedClass());
|
assertNotNull(classificationResult.getAssignedClass());
|
||||||
assertEquals(new BytesRef("technology"), classificationResult.getAssignedClass());
|
assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
|
||||||
assertTrue(classificationResult.getScore() > 0);
|
assertTrue("got a not positive score " + classificationResult.getScore(), classificationResult.getScore() > 0);
|
||||||
} finally {
|
} finally {
|
||||||
if (compositeReaderWrapper != null)
|
if (compositeReaderWrapper != null)
|
||||||
compositeReaderWrapper.close();
|
compositeReaderWrapper.close();
|
||||||
|
@ -86,48 +87,55 @@ public abstract class ClassificationTestBase extends LuceneTestCase {
|
||||||
doc.add(new Field(textFieldName, "The traveling press secretary for Mitt Romney lost his cool and cursed at reporters " +
|
doc.add(new Field(textFieldName, "The traveling press secretary for Mitt Romney lost his cool and cursed at reporters " +
|
||||||
"who attempted to ask questions of the Republican presidential candidate in a public plaza near the Tomb of " +
|
"who attempted to ask questions of the Republican presidential candidate in a public plaza near the Tomb of " +
|
||||||
"the Unknown Soldier in Warsaw Tuesday.", ft));
|
"the Unknown Soldier in Warsaw Tuesday.", ft));
|
||||||
doc.add(new Field(classFieldName, "politics", ft));
|
doc.add(new Field(categoryFieldName, "politics", ft));
|
||||||
|
doc.add(new Field(booleanFieldName, "false", ft));
|
||||||
|
|
||||||
indexWriter.addDocument(doc, analyzer);
|
indexWriter.addDocument(doc, analyzer);
|
||||||
|
|
||||||
doc = new Document();
|
doc = new Document();
|
||||||
doc.add(new Field(textFieldName, "Mitt Romney seeks to assure Israel and Iran, as well as Jewish voters in the United" +
|
doc.add(new Field(textFieldName, "Mitt Romney seeks to assure Israel and Iran, as well as Jewish voters in the United" +
|
||||||
" States, that he will be tougher against Iran's nuclear ambitions than President Barack Obama.", ft));
|
" States, that he will be tougher against Iran's nuclear ambitions than President Barack Obama.", ft));
|
||||||
doc.add(new Field(classFieldName, "politics", ft));
|
doc.add(new Field(categoryFieldName, "politics", ft));
|
||||||
|
doc.add(new Field(booleanFieldName, "false", ft));
|
||||||
indexWriter.addDocument(doc, analyzer);
|
indexWriter.addDocument(doc, analyzer);
|
||||||
|
|
||||||
doc = new Document();
|
doc = new Document();
|
||||||
doc.add(new Field(textFieldName, "And there's a threshold question that he has to answer for the American people and " +
|
doc.add(new Field(textFieldName, "And there's a threshold question that he has to answer for the American people and " +
|
||||||
"that's whether he is prepared to be commander-in-chief,\" she continued. \"As we look to the past events, we " +
|
"that's whether he is prepared to be commander-in-chief,\" she continued. \"As we look to the past events, we " +
|
||||||
"know that this raises some questions about his preparedness and we'll see how the rest of his trip goes.\"", ft));
|
"know that this raises some questions about his preparedness and we'll see how the rest of his trip goes.\"", ft));
|
||||||
doc.add(new Field(classFieldName, "politics", ft));
|
doc.add(new Field(categoryFieldName, "politics", ft));
|
||||||
|
doc.add(new Field(booleanFieldName, "false", ft));
|
||||||
indexWriter.addDocument(doc, analyzer);
|
indexWriter.addDocument(doc, analyzer);
|
||||||
|
|
||||||
doc = new Document();
|
doc = new Document();
|
||||||
doc.add(new Field(textFieldName, "Still, when it comes to gun policy, many congressional Democrats have \"decided to " +
|
doc.add(new Field(textFieldName, "Still, when it comes to gun policy, many congressional Democrats have \"decided to " +
|
||||||
"keep quiet and not go there,\" said Alan Lizotte, dean and professor at the State University of New York at " +
|
"keep quiet and not go there,\" said Alan Lizotte, dean and professor at the State University of New York at " +
|
||||||
"Albany's School of Criminal Justice.", ft));
|
"Albany's School of Criminal Justice.", ft));
|
||||||
doc.add(new Field(classFieldName, "politics", ft));
|
doc.add(new Field(categoryFieldName, "politics", ft));
|
||||||
|
doc.add(new Field(booleanFieldName, "false", ft));
|
||||||
indexWriter.addDocument(doc, analyzer);
|
indexWriter.addDocument(doc, analyzer);
|
||||||
|
|
||||||
doc = new Document();
|
doc = new Document();
|
||||||
doc.add(new Field(textFieldName, "Standing amongst the thousands of people at the state Capitol, Jorstad, director of " +
|
doc.add(new Field(textFieldName, "Standing amongst the thousands of people at the state Capitol, Jorstad, director of " +
|
||||||
"technology at the University of Wisconsin-La Crosse, documented the historic moment and shared it with the " +
|
"technology at the University of Wisconsin-La Crosse, documented the historic moment and shared it with the " +
|
||||||
"world through the Internet.", ft));
|
"world through the Internet.", ft));
|
||||||
doc.add(new Field(classFieldName, "technology", ft));
|
doc.add(new Field(categoryFieldName, "technology", ft));
|
||||||
|
doc.add(new Field(booleanFieldName, "true", ft));
|
||||||
indexWriter.addDocument(doc, analyzer);
|
indexWriter.addDocument(doc, analyzer);
|
||||||
|
|
||||||
doc = new Document();
|
doc = new Document();
|
||||||
doc.add(new Field(textFieldName, "So, about all those experts and analysts who've spent the past year or so saying " +
|
doc.add(new Field(textFieldName, "So, about all those experts and analysts who've spent the past year or so saying " +
|
||||||
"Facebook was going to make a phone. A new expert has stepped forward to say it's not going to happen.", ft));
|
"Facebook was going to make a phone. A new expert has stepped forward to say it's not going to happen.", ft));
|
||||||
doc.add(new Field(classFieldName, "technology", ft));
|
doc.add(new Field(categoryFieldName, "technology", ft));
|
||||||
|
doc.add(new Field(booleanFieldName, "true", ft));
|
||||||
indexWriter.addDocument(doc, analyzer);
|
indexWriter.addDocument(doc, analyzer);
|
||||||
|
|
||||||
doc = new Document();
|
doc = new Document();
|
||||||
doc.add(new Field(textFieldName, "More than 400 million people trust Google with their e-mail, and 50 million store files" +
|
doc.add(new Field(textFieldName, "More than 400 million people trust Google with their e-mail, and 50 million store files" +
|
||||||
" in the cloud using the Dropbox service. People manage their bank accounts, pay bills, trade stocks and " +
|
" in the cloud using the Dropbox service. People manage their bank accounts, pay bills, trade stocks and " +
|
||||||
"generally transfer or store huge volumes of personal data online.", ft));
|
"generally transfer or store huge volumes of personal data online.", ft));
|
||||||
doc.add(new Field(classFieldName, "technology", ft));
|
doc.add(new Field(categoryFieldName, "technology", ft));
|
||||||
|
doc.add(new Field(booleanFieldName, "true", ft));
|
||||||
indexWriter.addDocument(doc, analyzer);
|
indexWriter.addDocument(doc, analyzer);
|
||||||
|
|
||||||
indexWriter.commit();
|
indexWriter.commit();
|
||||||
|
|
|
@ -17,16 +17,17 @@
|
||||||
package org.apache.lucene.classification;
|
package org.apache.lucene.classification;
|
||||||
|
|
||||||
import org.apache.lucene.analysis.MockAnalyzer;
|
import org.apache.lucene.analysis.MockAnalyzer;
|
||||||
|
import org.apache.lucene.util.BytesRef;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Testcase for {@link KNearestNeighborClassifier}
|
* Testcase for {@link KNearestNeighborClassifier}
|
||||||
*/
|
*/
|
||||||
public class KNearestNeighborClassifierTest extends ClassificationTestBase {
|
public class KNearestNeighborClassifierTest extends ClassificationTestBase<BytesRef> {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBasicUsage() throws Exception {
|
public void testBasicUsage() throws Exception {
|
||||||
checkCorrectClassification(new KNearestNeighborClassifier(1), new MockAnalyzer(random()));
|
checkCorrectClassification(new KNearestNeighborClassifier(1), new BytesRef("technology"), new MockAnalyzer(random()), categoryFieldName);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.apache.lucene.classification;
|
||||||
import org.apache.lucene.analysis.Analyzer;
|
import org.apache.lucene.analysis.Analyzer;
|
||||||
import org.apache.lucene.analysis.MockAnalyzer;
|
import org.apache.lucene.analysis.MockAnalyzer;
|
||||||
import org.apache.lucene.analysis.ngram.EdgeNGramTokenizer;
|
import org.apache.lucene.analysis.ngram.EdgeNGramTokenizer;
|
||||||
|
import org.apache.lucene.util.BytesRef;
|
||||||
import org.apache.lucene.util.LuceneTestCase;
|
import org.apache.lucene.util.LuceneTestCase;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
|
@ -29,16 +30,16 @@ import java.io.Reader;
|
||||||
*/
|
*/
|
||||||
// TODO : eventually remove this if / when fallback methods exist for all un-supportable codec methods (see LUCENE-4872)
|
// TODO : eventually remove this if / when fallback methods exist for all un-supportable codec methods (see LUCENE-4872)
|
||||||
@LuceneTestCase.SuppressCodecs("Lucene3x")
|
@LuceneTestCase.SuppressCodecs("Lucene3x")
|
||||||
public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase {
|
public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase<BytesRef> {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBasicUsage() throws Exception {
|
public void testBasicUsage() throws Exception {
|
||||||
checkCorrectClassification(new SimpleNaiveBayesClassifier(), new MockAnalyzer(random()));
|
checkCorrectClassification(new SimpleNaiveBayesClassifier(), new BytesRef("technology"), new MockAnalyzer(random()), categoryFieldName);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNGramUsage() throws Exception {
|
public void testNGramUsage() throws Exception {
|
||||||
checkCorrectClassification(new SimpleNaiveBayesClassifier(), new NGramAnalyzer());
|
checkCorrectClassification(new SimpleNaiveBayesClassifier(), new BytesRef("technology"), new NGramAnalyzer(), categoryFieldName);
|
||||||
}
|
}
|
||||||
|
|
||||||
private class NGramAnalyzer extends Analyzer {
|
private class NGramAnalyzer extends Analyzer {
|
||||||
|
|
Loading…
Reference in New Issue