diff --git a/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java
index e21e670d2d5..22d530ac87d 100644
--- a/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java
+++ b/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java
@@ -49,6 +49,9 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
   private final int k;
   private Query query;
 
+  private int minDocsFreq;
+  private int minTermFreq;
+
   /**
    * Create a {@link Classifier} using kNN algorithm
    *
@@ -58,6 +61,19 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
     this.k = k;
   }
 
+  /**
+   * Create a {@link Classifier} using kNN algorithm
+   *
+   * @param k the number of neighbors to analyze as an int
+   * @param minDocsFreq the minimum number of docs frequency for MLT to be set with {@link MoreLikeThis#setMinDocFreq(int)}
+   * @param minTermFreq the minimum number of term frequency for MLT to be set with {@link MoreLikeThis#setMinTermFreq(int)}
+   */
+  public KNearestNeighborClassifier(int k, int minDocsFreq, int minTermFreq) {
+    this.k = k;
+    this.minDocsFreq = minDocsFreq;
+    this.minTermFreq = minTermFreq;
+  }
+
   /**
    * {@inheritDoc}
    */
@@ -93,11 +109,11 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
     }
     double max = 0;
     BytesRef assignedClass = new BytesRef();
-    for (BytesRef cl : classCounts.keySet()) {
-      Integer count = classCounts.get(cl);
+    for (Map.Entry<BytesRef, Integer> entry : classCounts.entrySet()) {
+      Integer count = entry.getValue();
       if (count > max) {
         max = count;
-        assignedClass = cl.clone();
+        assignedClass = entry.getKey().clone();
       }
     }
     double score = max / (double) k;
@@ -117,13 +133,7 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
    */
   @Override
   public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) throws IOException {
-    this.textFieldNames = new String[]{textFieldName};
-    this.classFieldName = classFieldName;
-    mlt = new MoreLikeThis(atomicReader);
-    mlt.setAnalyzer(analyzer);
-    mlt.setFieldNames(new String[]{textFieldName});
-    indexSearcher = new IndexSearcher(atomicReader);
-    this.query = query;
+    train(atomicReader, new String[]{textFieldName}, classFieldName, analyzer, query);
   }
 
   /**
@@ -137,6 +147,12 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
     mlt.setAnalyzer(analyzer);
     mlt.setFieldNames(textFieldNames);
     indexSearcher = new IndexSearcher(atomicReader);
+    if (minDocsFreq > 0) {
+      mlt.setMinDocFreq(minDocsFreq);
+    }
+    if (minTermFreq > 0) {
+      mlt.setMinTermFreq(minTermFreq);
+    }
     this.query = query;
   }
 }
diff --git a/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java b/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
index cd488043fd5..f8de59fa90f 100644
--- a/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
+++ b/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
@@ -39,14 +39,17 @@ import java.util.Random;
  * Base class for testing {@link Classifier}s
  */
 public abstract class ClassificationTestBase<T> extends LuceneTestCase {
-  public final static String POLITICS_INPUT = "Here are some interesting questions and answers about Mitt Romney.. If you don't know the answer to the question about Mitt Romney, then simply click on the answer below the question section.";
+  public final static String POLITICS_INPUT = "Here are some interesting questions and answers about Mitt Romney.. " +
+      "If you don't know the answer to the question about Mitt Romney, then simply click on the answer below the question section.";
   public static final BytesRef POLITICS_RESULT = new BytesRef("politics");
 
-  public static final String TECHNOLOGY_INPUT = "Much is made of what the likes of Facebook, Google and Apple know about users. Truth is, Amazon may know more.";
+  public static final String TECHNOLOGY_INPUT = "Much is made of what the likes of Facebook, Google and Apple know about users." +
+      " Truth is, Amazon may know more.";
   public static final BytesRef TECHNOLOGY_RESULT = new BytesRef("technology");
 
   private RandomIndexWriter indexWriter;
   private Directory dir;
+  private FieldType ft;
 
   String textFieldName;
   String categoryFieldName;
@@ -61,6 +64,10 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
     textFieldName = "text";
     categoryFieldName = "cat";
     booleanFieldName = "bool";
+    ft = new FieldType(TextField.TYPE_STORED);
+    ft.setStoreTermVectors(true);
+    ft.setStoreTermVectorOffsets(true);
+    ft.setStoreTermVectorPositions(true);
   }
 
   @Override
@@ -72,7 +79,7 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
   }
 
   protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName) throws Exception {
-    checkCorrectClassification(classifier, inputDoc, expectedResult, analyzer, textFieldName, classFieldName, null);
+    checkCorrectClassification(classifier, inputDoc, expectedResult, analyzer, textFieldName, classFieldName, null);
   }
 
   protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName, Query query) throws Exception {
@@ -90,63 +97,35 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
         atomicReader.close();
     }
   }
+  protected void checkOnlineClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName) throws Exception {
+    checkOnlineClassification(classifier, inputDoc, expectedResult, analyzer, textFieldName, classFieldName, null);
+  }
 
-  protected void checkPerformance(Classifier<T> classifier, Analyzer analyzer, String classFieldName) throws Exception {
+  protected void checkOnlineClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName, Query query) throws Exception {
     AtomicReader atomicReader = null;
-    long trainStart = System.currentTimeMillis();
     try {
-      populatePerformanceIndex(analyzer);
+      populateSampleIndex(analyzer);
       atomicReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
-      classifier.train(atomicReader, textFieldName, classFieldName, analyzer);
-      long trainEnd = System.currentTimeMillis();
-      long trainTime = trainEnd - trainStart;
-      assertTrue("training took more than 2 mins : " + trainTime / 1000 + "s", trainTime < 120000);
+      classifier.train(atomicReader, textFieldName, classFieldName, analyzer, query);
+      ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
+      assertNotNull(classificationResult.getAssignedClass());
+      assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
+      assertTrue("got a not positive score " + classificationResult.getScore(), classificationResult.getScore() > 0);
+      updateSampleIndex(analyzer);
+      ClassificationResult<T> secondClassificationResult = classifier.assignClass(inputDoc);
+      assertEquals(classificationResult.getAssignedClass(), secondClassificationResult.getAssignedClass());
+      assertEquals(Double.valueOf(classificationResult.getScore()), Double.valueOf(secondClassificationResult.getScore()));
+
     } finally {
       if (atomicReader != null)
         atomicReader.close();
     }
   }
 
-  private void populatePerformanceIndex(Analyzer analyzer) throws IOException {
+  private void populateSampleIndex(Analyzer analyzer) throws IOException {
     indexWriter.deleteAll();
     indexWriter.commit();
 
-    FieldType ft = new FieldType(TextField.TYPE_STORED);
-    ft.setStoreTermVectors(true);
-    ft.setStoreTermVectorOffsets(true);
-    ft.setStoreTermVectorPositions(true);
-    int docs = 1000;
-    Random random = random();
-    for (int i = 0; i < docs; i++) {
-      boolean b = random.nextBoolean();
-      Document doc = new Document();
-      doc.add(new Field(textFieldName, createRandomString(random), ft));
-      doc.add(new Field(categoryFieldName, b ? "technology" : "politics", ft));
-      doc.add(new Field(booleanFieldName, String.valueOf(b), ft));
-      indexWriter.addDocument(doc, analyzer);
-    }
-    indexWriter.commit();
-  }
-
-  private String createRandomString(Random random) {
-    StringBuilder builder = new StringBuilder();
-    for (int i = 0; i < 20; i++) {
-      builder.append(_TestUtil.randomSimpleString(random, 5));
-      builder.append(" ");
-    }
-    return builder.toString();
-  }
-
-  private void populateSampleIndex(Analyzer analyzer) throws Exception {
-
-    indexWriter.deleteAll();
-    indexWriter.commit();
-
-    FieldType ft = new FieldType(TextField.TYPE_STORED);
-    ft.setStoreTermVectors(true);
-    ft.setStoreTermVectorOffsets(true);
-    ft.setStoreTermVectorPositions(true);
-
     String text;
 
     Document doc = new Document();
@@ -218,4 +197,112 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
 
     indexWriter.commit();
   }
+
+  protected void checkPerformance(Classifier<T> classifier, Analyzer analyzer, String classFieldName) throws Exception {
+    AtomicReader atomicReader = null;
+    long trainStart = System.currentTimeMillis();
+    try {
+      populatePerformanceIndex(analyzer);
+      atomicReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
+      classifier.train(atomicReader, textFieldName, classFieldName, analyzer);
+      long trainEnd = System.currentTimeMillis();
+      long trainTime = trainEnd - trainStart;
+      assertTrue("training took more than 2 mins : " + trainTime / 1000 + "s", trainTime < 120000);
+    } finally {
+      if (atomicReader != null)
+        atomicReader.close();
+    }
+  }
+
+  private void populatePerformanceIndex(Analyzer analyzer) throws IOException {
+    indexWriter.deleteAll();
+    indexWriter.commit();
+
+    FieldType ft = new FieldType(TextField.TYPE_STORED);
+    ft.setStoreTermVectors(true);
+    ft.setStoreTermVectorOffsets(true);
+    ft.setStoreTermVectorPositions(true);
+    int docs = 1000;
+    Random random = random();
+    for (int i = 0; i < docs; i++) {
+      boolean b = random.nextBoolean();
+      Document doc = new Document();
+      doc.add(new Field(textFieldName, createRandomString(random), ft));
+      doc.add(new Field(categoryFieldName, b ? "technology" : "politics", ft));
+      doc.add(new Field(booleanFieldName, String.valueOf(b), ft));
+      indexWriter.addDocument(doc, analyzer);
+    }
+    indexWriter.commit();
+  }
+
+  private String createRandomString(Random random) {
+    StringBuilder builder = new StringBuilder();
+    for (int i = 0; i < 20; i++) {
+      builder.append(_TestUtil.randomSimpleString(random, 5));
+      builder.append(" ");
+    }
+    return builder.toString();
+  }
+
+  private void updateSampleIndex(Analyzer analyzer) throws Exception {
+
+    String text;
+
+    Document doc = new Document();
+    text = "Warren Bennis says John F. Kennedy grasped a key lesson about the presidency that few have followed.";
+    doc.add(new Field(textFieldName, text, ft));
+    doc.add(new Field(categoryFieldName, "politics", ft));
+    doc.add(new Field(booleanFieldName, "true", ft));
+
+    indexWriter.addDocument(doc, analyzer);
+
+    doc = new Document();
+    text = "Julian Zelizer says Bill Clinton is still trying to shape his party, years after the White House, while George W. Bush opts for a much more passive role.";
+    doc.add(new Field(textFieldName, text, ft));
+    doc.add(new Field(categoryFieldName, "politics", ft));
+    doc.add(new Field(booleanFieldName, "true", ft));
+    indexWriter.addDocument(doc, analyzer);
+
+    doc = new Document();
+    text = "Crossfire: Sen. Tim Scott passes on Sen. Lindsey Graham endorsement";
+    doc.add(new Field(textFieldName, text, ft));
+    doc.add(new Field(categoryFieldName, "politics", ft));
+    doc.add(new Field(booleanFieldName, "true", ft));
+    indexWriter.addDocument(doc, analyzer);
+
+    doc = new Document();
+    text = "Illinois becomes 16th state to allow same-sex marriage.";
+    doc.add(new Field(textFieldName, text, ft));
+    doc.add(new Field(categoryFieldName, "politics", ft));
+    doc.add(new Field(booleanFieldName, "true", ft));
+    indexWriter.addDocument(doc, analyzer);
+
+    doc = new Document();
+    text = "Apple is developing iPhones with curved-glass screens and enhanced sensors that detect different levels of pressure, according to a new report.";
+    doc.add(new Field(textFieldName, text, ft));
+    doc.add(new Field(categoryFieldName, "technology", ft));
+    doc.add(new Field(booleanFieldName, "false", ft));
+    indexWriter.addDocument(doc, analyzer);
+
+    doc = new Document();
+    text = "The Xbox One is Microsoft's first new gaming console in eight years. It's a quality piece of hardware but it's also noteworthy because Microsoft is using it to make a statement.";
+    doc.add(new Field(textFieldName, text, ft));
+    doc.add(new Field(categoryFieldName, "technology", ft));
+    doc.add(new Field(booleanFieldName, "false", ft));
+    indexWriter.addDocument(doc, analyzer);
+
+    doc = new Document();
+    text = "Google says it will replace a Google Maps image after a California father complained it shows the body of his teen-age son, who was shot to death in 2009.";
+    doc.add(new Field(textFieldName, text, ft));
+    doc.add(new Field(categoryFieldName, "technology", ft));
+    doc.add(new Field(booleanFieldName, "false", ft));
+    indexWriter.addDocument(doc, analyzer);
+
+    doc = new Document();
+    text = "second unlabeled doc";
+    doc.add(new Field(textFieldName, text, ft));
+    indexWriter.addDocument(doc, analyzer);
+
+    indexWriter.commit();
+  }
 }
diff --git a/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java b/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java
index 7e754adb560..2a0308286f2 100644
--- a/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java
+++ b/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java
@@ -29,7 +29,10 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase