LUCENE-5348 - added minDoc/TermFreq params to kNN classifier

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1544435 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Tommaso Teofili 2013-11-22 08:36:22 +00:00
parent f9b3e389b2
commit 7b9ca4745a
3 changed files with 164 additions and 58 deletions

View File

@ -49,6 +49,9 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
private final int k;
private Query query;
private int minDocsFreq;
private int minTermFreq;
/**
* Create a {@link Classifier} using kNN algorithm
*
@ -58,6 +61,19 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
this.k = k;
}
/**
* Create a {@link Classifier} using kNN algorithm
*
* @param k the number of neighbors to analyze as an <code>int</code>
* @param minDocsFreq the minimum number of docs frequency for MLT to be set with {@link MoreLikeThis#setMinDocFreq(int)}
* @param minTermFreq the minimum number of term frequency for MLT to be set with {@link MoreLikeThis#setMinTermFreq(int)}
*/
public KNearestNeighborClassifier(int k, int minDocsFreq, int minTermFreq) {
this.k = k;
this.minDocsFreq = minDocsFreq;
this.minTermFreq = minTermFreq;
}
/**
* {@inheritDoc}
*/
@ -93,11 +109,11 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
}
double max = 0;
BytesRef assignedClass = new BytesRef();
for (BytesRef cl : classCounts.keySet()) {
Integer count = classCounts.get(cl);
for (Map.Entry<BytesRef, Integer> entry : classCounts.entrySet()) {
Integer count = entry.getValue();
if (count > max) {
max = count;
assignedClass = cl.clone();
assignedClass = entry.getKey().clone();
}
}
double score = max / (double) k;
@ -117,13 +133,7 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
*/
@Override
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) throws IOException {
this.textFieldNames = new String[]{textFieldName};
this.classFieldName = classFieldName;
mlt = new MoreLikeThis(atomicReader);
mlt.setAnalyzer(analyzer);
mlt.setFieldNames(new String[]{textFieldName});
indexSearcher = new IndexSearcher(atomicReader);
this.query = query;
train(atomicReader, new String[]{textFieldName}, classFieldName, analyzer, query);
}
/**
@ -137,6 +147,12 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
mlt.setAnalyzer(analyzer);
mlt.setFieldNames(textFieldNames);
indexSearcher = new IndexSearcher(atomicReader);
if (minDocsFreq > 0) {
mlt.setMinDocFreq(minDocsFreq);
}
if (minTermFreq > 0) {
mlt.setMinTermFreq(minTermFreq);
}
this.query = query;
}
}

View File

@ -39,14 +39,17 @@ import java.util.Random;
* Base class for testing {@link Classifier}s
*/
public abstract class ClassificationTestBase<T> extends LuceneTestCase {
public final static String POLITICS_INPUT = "Here are some interesting questions and answers about Mitt Romney.. If you don't know the answer to the question about Mitt Romney, then simply click on the answer below the question section.";
public final static String POLITICS_INPUT = "Here are some interesting questions and answers about Mitt Romney.. " +
"If you don't know the answer to the question about Mitt Romney, then simply click on the answer below the question section.";
public static final BytesRef POLITICS_RESULT = new BytesRef("politics");
public static final String TECHNOLOGY_INPUT = "Much is made of what the likes of Facebook, Google and Apple know about users. Truth is, Amazon may know more.";
public static final String TECHNOLOGY_INPUT = "Much is made of what the likes of Facebook, Google and Apple know about users." +
" Truth is, Amazon may know more.";
public static final BytesRef TECHNOLOGY_RESULT = new BytesRef("technology");
private RandomIndexWriter indexWriter;
private Directory dir;
private FieldType ft;
String textFieldName;
String categoryFieldName;
@ -61,6 +64,10 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
textFieldName = "text";
categoryFieldName = "cat";
booleanFieldName = "bool";
ft = new FieldType(TextField.TYPE_STORED);
ft.setStoreTermVectors(true);
ft.setStoreTermVectorOffsets(true);
ft.setStoreTermVectorPositions(true);
}
@Override
@ -90,63 +97,35 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
atomicReader.close();
}
}
protected void checkOnlineClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName) throws Exception {
checkOnlineClassification(classifier, inputDoc, expectedResult, analyzer, textFieldName, classFieldName, null);
}
protected void checkPerformance(Classifier<T> classifier, Analyzer analyzer, String classFieldName) throws Exception {
protected void checkOnlineClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName, Query query) throws Exception {
AtomicReader atomicReader = null;
long trainStart = System.currentTimeMillis();
try {
populatePerformanceIndex(analyzer);
populateSampleIndex(analyzer);
atomicReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
classifier.train(atomicReader, textFieldName, classFieldName, analyzer);
long trainEnd = System.currentTimeMillis();
long trainTime = trainEnd - trainStart;
assertTrue("training took more than 2 mins : " + trainTime / 1000 + "s", trainTime < 120000);
classifier.train(atomicReader, textFieldName, classFieldName, analyzer, query);
ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
assertNotNull(classificationResult.getAssignedClass());
assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
assertTrue("got a not positive score " + classificationResult.getScore(), classificationResult.getScore() > 0);
updateSampleIndex(analyzer);
ClassificationResult<T> secondClassificationResult = classifier.assignClass(inputDoc);
assertEquals(classificationResult.getAssignedClass(), secondClassificationResult.getAssignedClass());
assertEquals(Double.valueOf(classificationResult.getScore()), Double.valueOf(secondClassificationResult.getScore()));
} finally {
if (atomicReader != null)
atomicReader.close();
}
}
private void populatePerformanceIndex(Analyzer analyzer) throws IOException {
private void populateSampleIndex(Analyzer analyzer) throws IOException {
indexWriter.deleteAll();
indexWriter.commit();
FieldType ft = new FieldType(TextField.TYPE_STORED);
ft.setStoreTermVectors(true);
ft.setStoreTermVectorOffsets(true);
ft.setStoreTermVectorPositions(true);
int docs = 1000;
Random random = random();
for (int i = 0; i < docs; i++) {
boolean b = random.nextBoolean();
Document doc = new Document();
doc.add(new Field(textFieldName, createRandomString(random), ft));
doc.add(new Field(categoryFieldName, b ? "technology" : "politics", ft));
doc.add(new Field(booleanFieldName, String.valueOf(b), ft));
indexWriter.addDocument(doc, analyzer);
}
indexWriter.commit();
}
private String createRandomString(Random random) {
StringBuilder builder = new StringBuilder();
for (int i = 0; i < 20; i++) {
builder.append(_TestUtil.randomSimpleString(random, 5));
builder.append(" ");
}
return builder.toString();
}
private void populateSampleIndex(Analyzer analyzer) throws Exception {
indexWriter.deleteAll();
indexWriter.commit();
FieldType ft = new FieldType(TextField.TYPE_STORED);
ft.setStoreTermVectors(true);
ft.setStoreTermVectorOffsets(true);
ft.setStoreTermVectorPositions(true);
String text;
Document doc = new Document();
@ -218,4 +197,112 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
indexWriter.commit();
}
protected void checkPerformance(Classifier<T> classifier, Analyzer analyzer, String classFieldName) throws Exception {
AtomicReader atomicReader = null;
long trainStart = System.currentTimeMillis();
try {
populatePerformanceIndex(analyzer);
atomicReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
classifier.train(atomicReader, textFieldName, classFieldName, analyzer);
long trainEnd = System.currentTimeMillis();
long trainTime = trainEnd - trainStart;
assertTrue("training took more than 2 mins : " + trainTime / 1000 + "s", trainTime < 120000);
} finally {
if (atomicReader != null)
atomicReader.close();
}
}
private void populatePerformanceIndex(Analyzer analyzer) throws IOException {
indexWriter.deleteAll();
indexWriter.commit();
FieldType ft = new FieldType(TextField.TYPE_STORED);
ft.setStoreTermVectors(true);
ft.setStoreTermVectorOffsets(true);
ft.setStoreTermVectorPositions(true);
int docs = 1000;
Random random = random();
for (int i = 0; i < docs; i++) {
boolean b = random.nextBoolean();
Document doc = new Document();
doc.add(new Field(textFieldName, createRandomString(random), ft));
doc.add(new Field(categoryFieldName, b ? "technology" : "politics", ft));
doc.add(new Field(booleanFieldName, String.valueOf(b), ft));
indexWriter.addDocument(doc, analyzer);
}
indexWriter.commit();
}
private String createRandomString(Random random) {
StringBuilder builder = new StringBuilder();
for (int i = 0; i < 20; i++) {
builder.append(_TestUtil.randomSimpleString(random, 5));
builder.append(" ");
}
return builder.toString();
}
private void updateSampleIndex(Analyzer analyzer) throws Exception {
String text;
Document doc = new Document();
text = "Warren Bennis says John F. Kennedy grasped a key lesson about the presidency that few have followed.";
doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(categoryFieldName, "politics", ft));
doc.add(new Field(booleanFieldName, "true", ft));
indexWriter.addDocument(doc, analyzer);
doc = new Document();
text = "Julian Zelizer says Bill Clinton is still trying to shape his party, years after the White House, while George W. Bush opts for a much more passive role.";
doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(categoryFieldName, "politics", ft));
doc.add(new Field(booleanFieldName, "true", ft));
indexWriter.addDocument(doc, analyzer);
doc = new Document();
text = "Crossfire: Sen. Tim Scott passes on Sen. Lindsey Graham endorsement";
doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(categoryFieldName, "politics", ft));
doc.add(new Field(booleanFieldName, "true", ft));
indexWriter.addDocument(doc, analyzer);
doc = new Document();
text = "Illinois becomes 16th state to allow same-sex marriage.";
doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(categoryFieldName, "politics", ft));
doc.add(new Field(booleanFieldName, "true", ft));
indexWriter.addDocument(doc, analyzer);
doc = new Document();
text = "Apple is developing iPhones with curved-glass screens and enhanced sensors that detect different levels of pressure, according to a new report.";
doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(categoryFieldName, "technology", ft));
doc.add(new Field(booleanFieldName, "false", ft));
indexWriter.addDocument(doc, analyzer);
doc = new Document();
text = "The Xbox One is Microsoft's first new gaming console in eight years. It's a quality piece of hardware but it's also noteworthy because Microsoft is using it to make a statement.";
doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(categoryFieldName, "technology", ft));
doc.add(new Field(booleanFieldName, "false", ft));
indexWriter.addDocument(doc, analyzer);
doc = new Document();
text = "Google says it will replace a Google Maps image after a California father complained it shows the body of his teen-age son, who was shot to death in 2009.";
doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(categoryFieldName, "technology", ft));
doc.add(new Field(booleanFieldName, "false", ft));
indexWriter.addDocument(doc, analyzer);
doc = new Document();
text = "second unlabeled doc";
doc.add(new Field(textFieldName, text, ft));
indexWriter.addDocument(doc, analyzer);
indexWriter.commit();
}
}

View File

@ -29,7 +29,10 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase<Bytes
@Test
public void testBasicUsage() throws Exception {
checkCorrectClassification(new KNearestNeighborClassifier(1), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
// usage with default MLT min docs / term freq
checkCorrectClassification(new KNearestNeighborClassifier(3), POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
// usage without custom min docs / term freq for MLT
checkCorrectClassification(new KNearestNeighborClassifier(3, 2, 1), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
}
@Test