LUCENE-6479 - improved cm testing, added stats, minor fixes

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1700914 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Tommaso Teofili 2015-09-02 22:21:53 +00:00
parent e16e914057
commit 8b2e0d937d
10 changed files with 294 additions and 72 deletions

View File

@ -173,7 +173,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
// update weights
Long previousValue = Util.get(fst, term);
String termString = term.utf8ToString();
weights.put(termString, previousValue + modifier * termFreqLocal);
weights.put(termString, previousValue == null ? 0 : previousValue + modifier * termFreqLocal);
}
}
if (updateFST) {
@ -214,6 +214,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
}
}
tokenStream.end();
tokenStream.close();
}
double score = 1 - Math.exp(-1 * Math.abs(threshold - output.doubleValue()) / threshold);

View File

@ -80,7 +80,7 @@ public class CachingNaiveBayesClassifier extends SimpleNaiveBayesClassifier {
}
private List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException {
protected List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException {
String[] tokenizedDoc = tokenizeDoc(inputDocument);
List<ClassificationResult<BytesRef>> dataList = calculateLogLikelihood(tokenizedDoc);
@ -200,7 +200,7 @@ public class CachingNaiveBayesClassifier extends SimpleNaiveBayesClassifier {
}
}
if (insertPoint != null) {
// threadsafe and concurent write
// threadsafe and concurrent write
termCClassHitCache.put(word, searched);
}
}

View File

@ -134,7 +134,13 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
return doclist.subList(0, max);
}
private List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException {
/**
* Calculate probabilities for all classes for a given input text
* @param inputDocument the input text as a {@code String}
* @return a {@code List} of {@code ClassificationResult}, one for each existing class
* @throws IOException if assigning probabilities fails
*/
protected List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException {
List<ClassificationResult<BytesRef>> dataList = new ArrayList<>();
Terms terms = MultiFields.getTerms(leafReader, classFieldName);
@ -143,8 +149,10 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
String[] tokenizedDoc = tokenizeDoc(inputDocument);
int docsWithClassSize = countDocsWithClass();
while ((next = termsEnum.next()) != null) {
double clVal = calculateLogPrior(next, docsWithClassSize) + calculateLogLikelihood(tokenizedDoc, next, docsWithClassSize);
dataList.add(new ClassificationResult<>(BytesRef.deepCopyOf(next), clVal));
if (next.length > 0) {
double clVal = calculateLogPrior(next, docsWithClassSize) + calculateLogLikelihood(tokenizedDoc, next, docsWithClassSize);
dataList.add(new ClassificationResult<>(BytesRef.deepCopyOf(next), clVal));
}
}
// normalization; the values transforms to a 0-1 range
@ -212,6 +220,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
result.add(charTermAttribute.toString());
}
tokenStream.end();
tokenStream.close();
}
}
return result.toArray(new String[result.size()]);

View File

@ -21,11 +21,21 @@ import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.Classifier;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.StoredDocument;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.WildcardQuery;
import org.apache.lucene.util.BytesRef;
/**
@ -49,37 +59,67 @@ public class ConfusionMatrixGenerator {
* @return a {@link org.apache.lucene.classification.utils.ConfusionMatrixGenerator.ConfusionMatrix}
* @throws IOException if problems occur while reading the index or using the classifier
*/
public static <T> ConfusionMatrix getConfusionMatrix(LeafReader reader, Classifier<T> classifier, String classFieldName, String textFieldName) throws IOException {
public static <T> ConfusionMatrix getConfusionMatrix(LeafReader reader, Classifier<T> classifier, String classFieldName,
String textFieldName) throws IOException {
Map<String, Map<String, Long>> counts = new HashMap<>();
ExecutorService executorService = Executors.newFixedThreadPool(1);
for (int i = 0; i < reader.maxDoc(); i++) {
StoredDocument doc = reader.document(i);
String correctAnswer = doc.get(classFieldName);
try {
if (correctAnswer != null && correctAnswer.length() > 0) {
Map<String, Map<String, Long>> counts = new HashMap<>();
IndexSearcher indexSearcher = new IndexSearcher(reader);
TopDocs topDocs = indexSearcher.search(new WildcardQuery(new Term(classFieldName, "*")), Integer.MAX_VALUE);
double time = 0d;
ClassificationResult<T> result = classifier.assignClass(doc.get(textFieldName));
T assignedClass = result.getAssignedClass();
String classified = assignedClass instanceof BytesRef ? ((BytesRef) assignedClass).utf8ToString() : assignedClass.toString();
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
StoredDocument doc = reader.document(scoreDoc.doc);
String correctAnswer = doc.get(classFieldName);
if (correctAnswer != null && correctAnswer.length() > 0) {
ClassificationResult<T> result;
String text = doc.get(textFieldName);
if (text != null) {
try {
// fail if classification takes more than 5s
long start = System.currentTimeMillis();
result = executorService.submit(() -> classifier.assignClass(text)).get(5, TimeUnit.SECONDS);
long end = System.currentTimeMillis();
time += end - start;
if (result != null) {
T assignedClass = result.getAssignedClass();
if (assignedClass != null) {
String classified = assignedClass instanceof BytesRef ? ((BytesRef) assignedClass).utf8ToString() : assignedClass.toString();
Map<String, Long> stringLongMap = counts.get(correctAnswer);
if (stringLongMap != null) {
Long aLong = stringLongMap.get(classified);
if (aLong != null) {
stringLongMap.put(classified, aLong + 1);
} else {
stringLongMap.put(classified, 1l);
}
} else {
stringLongMap = new HashMap<>();
stringLongMap.put(classified, 1l);
counts.put(correctAnswer, stringLongMap);
}
}
}
} catch (TimeoutException timeoutException) {
// add timeout
time += 5000;
} catch (ExecutionException | InterruptedException executionException) {
throw new RuntimeException(executionException);
}
Map<String, Long> stringLongMap = counts.get(correctAnswer);
if (stringLongMap != null) {
Long aLong = stringLongMap.get(classified);
if (aLong != null) {
stringLongMap.put(classified, aLong + 1);
} else {
stringLongMap.put(classified, 1l);
}
} else {
stringLongMap = new HashMap<>();
stringLongMap.put(classified, 1l);
counts.put(correctAnswer, stringLongMap);
}
}
return new ConfusionMatrix(counts, time / topDocs.totalHits, topDocs.totalHits);
} finally {
executorService.shutdown();
}
return new ConfusionMatrix(counts);
}
/**
@ -88,9 +128,13 @@ public class ConfusionMatrixGenerator {
public static class ConfusionMatrix {
private final Map<String, Map<String, Long>> linearizedMatrix;
private final double avgClassificationTime;
private final int numberOfEvaluatedDocs;
private ConfusionMatrix(Map<String, Map<String, Long>> linearizedMatrix) {
private ConfusionMatrix(Map<String, Map<String, Long>> linearizedMatrix, double avgClassificationTime, int numberOfEvaluatedDocs) {
this.linearizedMatrix = linearizedMatrix;
this.avgClassificationTime = avgClassificationTime;
this.numberOfEvaluatedDocs = numberOfEvaluatedDocs;
}
/**
@ -104,8 +148,26 @@ public class ConfusionMatrixGenerator {
@Override
public String toString() {
return "ConfusionMatrix{" +
"linearizedMatrix=" + linearizedMatrix +
'}';
"linearizedMatrix=" + linearizedMatrix +
", avgClassificationTime=" + avgClassificationTime +
", numberOfEvaluatedDocs=" + numberOfEvaluatedDocs +
'}';
}
/**
* get the average classification time in milliseconds
* @return the avg classification time
*/
public double getAvgClassificationTime() {
return avgClassificationTime;
}
/**
* get the no. of documents evaluated while generating this confusion matrix
* @return the no. of documents evaluated
*/
public int getNumberOfEvaluatedDocs() {
return numberOfEvaluatedDocs;
}
}
}

View File

@ -17,6 +17,7 @@
package org.apache.lucene.classification;
import org.apache.lucene.analysis.MockAnalyzer;
import org.apache.lucene.classification.utils.ConfusionMatrixGenerator;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.TermQuery;
@ -32,7 +33,7 @@ public class BooleanPerceptronClassifierTest extends ClassificationTestBase<Bool
LeafReader leafReader = null;
try {
MockAnalyzer analyzer = new MockAnalyzer(random());
leafReader = populateSampleIndex(analyzer);
leafReader = getSampleIndex(analyzer);
checkCorrectClassification(new BooleanPerceptronClassifier(leafReader, analyzer, null, 1, null, booleanFieldName, textFieldName), TECHNOLOGY_INPUT, false);
} finally {
if (leafReader != null) {
@ -46,7 +47,7 @@ public class BooleanPerceptronClassifierTest extends ClassificationTestBase<Bool
LeafReader leafReader = null;
try {
MockAnalyzer analyzer = new MockAnalyzer(random());
leafReader = populateSampleIndex(analyzer);
leafReader = getSampleIndex(analyzer);
BooleanPerceptronClassifier classifier = new BooleanPerceptronClassifier(leafReader, analyzer, null, 1, 50d, booleanFieldName, textFieldName);
checkCorrectClassification(classifier, TECHNOLOGY_INPUT, false);
checkCorrectClassification(classifier, POLITICS_INPUT, true);
@ -63,7 +64,7 @@ public class BooleanPerceptronClassifierTest extends ClassificationTestBase<Bool
LeafReader leafReader = null;
try {
MockAnalyzer analyzer = new MockAnalyzer(random());
leafReader = populateSampleIndex(analyzer);
leafReader = getSampleIndex(analyzer);
checkCorrectClassification(new BooleanPerceptronClassifier(leafReader, analyzer, query, 1, null, booleanFieldName, textFieldName), TECHNOLOGY_INPUT, false);
} finally {
if (leafReader != null) {
@ -72,4 +73,29 @@ public class BooleanPerceptronClassifierTest extends ClassificationTestBase<Bool
}
}
/**
 * Rough performance check for {@code BooleanPerceptronClassifier}: training and
 * confusion-matrix evaluation over a random 100-document index must complete
 * within generous wall-clock bounds.
 */
@Test
public void testPerformance() throws Exception {
  MockAnalyzer analyzer = new MockAnalyzer(random());
  LeafReader leafReader = getRandomIndex(analyzer, 100);
  try {
    long trainStart = System.currentTimeMillis();
    BooleanPerceptronClassifier classifier = new BooleanPerceptronClassifier(leafReader, analyzer, null, 1, 0d, booleanFieldName, textFieldName);
    long trainEnd = System.currentTimeMillis();
    long trainTime = trainEnd - trainStart;
    assertTrue("training took more than 10s: " + trainTime / 1000 + "s", trainTime < 10000);
    long evaluationStart = System.currentTimeMillis();
    // evaluate against booleanFieldName: the perceptron assigns Boolean classes
    // from the boolean field it was trained on; using categoryFieldName (random
    // ints in getRandomIndex) would make every matrix entry a mismatch
    ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(leafReader,
        classifier, booleanFieldName, textFieldName);
    assertNotNull(confusionMatrix);
    long evaluationEnd = System.currentTimeMillis();
    long evaluationTime = evaluationEnd - evaluationStart;
    assertTrue("evaluation took more than 1m: " + evaluationTime / 1000 + "s", evaluationTime < 60000);
    // each single classification should stay under the generator's 5s timeout
    double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
    assertTrue(5000 > avgClassificationTime);
  } finally {
    leafReader.close();
  }
}
}

View File

@ -23,8 +23,8 @@ import org.apache.lucene.analysis.Tokenizer;
import org.apache.lucene.analysis.core.KeywordTokenizer;
import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
import org.apache.lucene.analysis.reverse.ReverseStringFilter;
import org.apache.lucene.classification.utils.ConfusionMatrixGenerator;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.SlowCompositeReaderWrapper;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.util.BytesRef;
@ -40,7 +40,7 @@ public class CachingNaiveBayesClassifierTest extends ClassificationTestBase<Byte
LeafReader leafReader = null;
try {
MockAnalyzer analyzer = new MockAnalyzer(random());
leafReader = populateSampleIndex(analyzer);
leafReader = getSampleIndex(analyzer);
checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), POLITICS_INPUT, POLITICS_RESULT);
} finally {
@ -55,7 +55,7 @@ public class CachingNaiveBayesClassifierTest extends ClassificationTestBase<Byte
LeafReader leafReader = null;
try {
MockAnalyzer analyzer = new MockAnalyzer(random());
leafReader = populateSampleIndex(analyzer);
leafReader = getSampleIndex(analyzer);
TermQuery query = new TermQuery(new Term(textFieldName, "it"));
checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, query, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
} finally {
@ -70,7 +70,7 @@ public class CachingNaiveBayesClassifierTest extends ClassificationTestBase<Byte
LeafReader leafReader = null;
try {
NGramAnalyzer analyzer = new NGramAnalyzer();
leafReader = populateSampleIndex(analyzer);
leafReader = getSampleIndex(analyzer);
checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
} finally {
if (leafReader != null) {
@ -87,4 +87,31 @@ public class CachingNaiveBayesClassifierTest extends ClassificationTestBase<Byte
}
}
/**
 * Rough performance check for {@code CachingNaiveBayesClassifier} over a
 * random 100-document index.
 */
@Test
public void testPerformance() throws Exception {
  MockAnalyzer analyzer = new MockAnalyzer(random());
  LeafReader leafReader = getRandomIndex(analyzer, 100);
  try {
    // training happens in the constructor; time it
    long trainingBegin = System.currentTimeMillis();
    CachingNaiveBayesClassifier classifier = new CachingNaiveBayesClassifier(leafReader,
        analyzer, null, categoryFieldName, textFieldName);
    long trainTime = System.currentTimeMillis() - trainingBegin;
    assertTrue("training took more than 10s: " + trainTime / 1000 + "s", trainTime < 10000);

    // building the confusion matrix classifies every indexed document
    long evaluationBegin = System.currentTimeMillis();
    ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(leafReader,
        classifier, categoryFieldName, textFieldName);
    assertNotNull(confusionMatrix);
    long evaluationTime = System.currentTimeMillis() - evaluationBegin;
    assertTrue("evaluation took more than 1m: " + evaluationTime / 1000 + "s", evaluationTime < 60000);

    // a single classification should stay under the generator's 5s timeout
    assertTrue(5000 > confusionMatrix.getAvgClassificationTime());
  } finally {
    leafReader.close();
  }
}
}

View File

@ -40,21 +40,21 @@ import org.junit.Before;
* Base class for testing {@link Classifier}s
*/
public abstract class ClassificationTestBase<T> extends LuceneTestCase {
public final static String POLITICS_INPUT = "Here are some interesting questions and answers about Mitt Romney.. " +
protected static final String POLITICS_INPUT = "Here are some interesting questions and answers about Mitt Romney.. " +
"If you don't know the answer to the question about Mitt Romney, then simply click on the answer below the question section.";
public static final BytesRef POLITICS_RESULT = new BytesRef("politics");
protected static final BytesRef POLITICS_RESULT = new BytesRef("politics");
public static final String TECHNOLOGY_INPUT = "Much is made of what the likes of Facebook, Google and Apple know about users." +
protected static final String TECHNOLOGY_INPUT = "Much is made of what the likes of Facebook, Google and Apple know about users." +
" Truth is, Amazon may know more.";
public static final String STRONG_TECHNOLOGY_INPUT = "Much is made of what the likes of Facebook, Google and Apple know about users." +
protected static final String STRONG_TECHNOLOGY_INPUT = "Much is made of what the likes of Facebook, Google and Apple know about users." +
" Truth is, Amazon may know more. This technology observation is extracted from the internet.";
public static final String SUPER_STRONG_TECHNOLOGY_INPUT = "More than 400 million people trust Google with their e-mail, and 50 million store files" +
protected static final String SUPER_STRONG_TECHNOLOGY_INPUT = "More than 400 million people trust Google with their e-mail, and 50 million store files" +
" in the cloud using the Dropbox service. People manage their bank accounts, pay bills, trade stocks and " +
"generally transfer or store huge volumes of personal data online. traveling seeks raises some questions Republican presidential. ";
public static final BytesRef TECHNOLOGY_RESULT = new BytesRef("technology");
protected static final BytesRef TECHNOLOGY_RESULT = new BytesRef("technology");
protected RandomIndexWriter indexWriter;
private Directory dir;
@ -101,7 +101,7 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
}
protected void checkOnlineClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName, Query query) throws Exception {
populateSampleIndex(analyzer);
getSampleIndex(analyzer);
ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
assertNotNull(classificationResult.getAssignedClass());
@ -115,7 +115,7 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
}
protected LeafReader populateSampleIndex(Analyzer analyzer) throws IOException {
protected LeafReader getSampleIndex(Analyzer analyzer) throws IOException {
indexWriter.close();
indexWriter = new RandomIndexWriter(random(), dir, newIndexWriterConfig(analyzer).setOpenMode(IndexWriterConfig.OpenMode.CREATE));
indexWriter.commit();
@ -193,34 +193,27 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
return SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
}
protected void checkPerformance(Classifier<T> classifier, Analyzer analyzer, String classFieldName) throws Exception {
long trainStart = System.currentTimeMillis();
populatePerformanceIndex(analyzer);
long trainEnd = System.currentTimeMillis();
long trainTime = trainEnd - trainStart;
assertTrue("training took more than 2 mins : " + trainTime / 1000 + "s", trainTime < 120000);
}
private void populatePerformanceIndex(Analyzer analyzer) throws IOException {
protected LeafReader getRandomIndex(Analyzer analyzer, int size) throws IOException {
indexWriter.close();
indexWriter = new RandomIndexWriter(random(), dir, newIndexWriterConfig(analyzer).setOpenMode(IndexWriterConfig.OpenMode.CREATE));
indexWriter.deleteAll();
indexWriter.commit();
FieldType ft = new FieldType(TextField.TYPE_STORED);
ft.setStoreTermVectors(true);
ft.setStoreTermVectorOffsets(true);
ft.setStoreTermVectorPositions(true);
int docs = 1000;
Random random = random();
for (int i = 0; i < docs; i++) {
for (int i = 0; i < size; i++) {
boolean b = random.nextBoolean();
Document doc = new Document();
doc.add(new Field(textFieldName, createRandomString(random), ft));
doc.add(new Field(categoryFieldName, b ? "technology" : "politics", ft));
doc.add(new Field(categoryFieldName, String.valueOf(random.nextInt(1000)), ft));
doc.add(new Field(booleanFieldName, String.valueOf(b), ft));
indexWriter.addDocument(doc);
}
indexWriter.commit();
return SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
}
private String createRandomString(Random random) {

View File

@ -21,6 +21,7 @@ import java.util.List;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.MockAnalyzer;
import org.apache.lucene.analysis.en.EnglishAnalyzer;
import org.apache.lucene.classification.utils.ConfusionMatrixGenerator;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.TermQuery;
@ -38,7 +39,7 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase<Bytes
LeafReader leafReader = null;
try {
MockAnalyzer analyzer = new MockAnalyzer(random());
leafReader = populateSampleIndex(analyzer);
leafReader = getSampleIndex(analyzer);
checkCorrectClassification(new KNearestNeighborClassifier(leafReader, null, analyzer, null, 1, 0, 0, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
checkCorrectClassification(new KNearestNeighborClassifier(leafReader, new LMDirichletSimilarity(), analyzer, null, 1, 0, 0, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
ClassificationResult<BytesRef> resultDS = checkCorrectClassification(new KNearestNeighborClassifier(leafReader, null, analyzer, null, 3, 2, 1, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
@ -63,7 +64,7 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase<Bytes
LeafReader leafReader = null;
try {
Analyzer analyzer = new EnglishAnalyzer();
leafReader = populateSampleIndex(analyzer);
leafReader = getSampleIndex(analyzer);
KNearestNeighborClassifier knnClassifier = new KNearestNeighborClassifier(leafReader, null, analyzer, null, 6, 1, 1, categoryFieldName, textFieldName);
List<ClassificationResult<BytesRef>> classes = knnClassifier.getClasses(STRONG_TECHNOLOGY_INPUT);
assertTrue(classes.get(0).getScore() > classes.get(1).getScore());
@ -88,7 +89,7 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase<Bytes
LeafReader leafReader = null;
try {
Analyzer analyzer = new EnglishAnalyzer();
leafReader = populateSampleIndex(analyzer);
leafReader = getSampleIndex(analyzer);
KNearestNeighborClassifier knnClassifier = new KNearestNeighborClassifier(leafReader, null,analyzer, null, 3, 1, 1, categoryFieldName, textFieldName);
List<ClassificationResult<BytesRef>> classes = knnClassifier.getClasses(SUPER_STRONG_TECHNOLOGY_INPUT);
assertTrue(classes.get(0).getScore() > classes.get(1).getScore());
@ -105,7 +106,7 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase<Bytes
LeafReader leafReader = null;
try {
MockAnalyzer analyzer = new MockAnalyzer(random());
leafReader = populateSampleIndex(analyzer);
leafReader = getSampleIndex(analyzer);
TermQuery query = new TermQuery(new Term(textFieldName, "it"));
checkCorrectClassification(new KNearestNeighborClassifier(leafReader, null, analyzer, query, 1, 0, 0, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
} finally {
@ -115,4 +116,30 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase<Bytes
}
}
/**
 * Rough performance check for {@code KNearestNeighborClassifier} over a
 * random 100-document index.
 */
@Test
public void testPerformance() throws Exception {
  MockAnalyzer analyzer = new MockAnalyzer(random());
  LeafReader leafReader = getRandomIndex(analyzer, 100);
  try {
    // training happens in the constructor; time it
    long trainingBegin = System.currentTimeMillis();
    KNearestNeighborClassifier knnClassifier = new KNearestNeighborClassifier(leafReader, null,
        analyzer, null, 1, 2, 2, categoryFieldName, textFieldName);
    long trainTime = System.currentTimeMillis() - trainingBegin;
    assertTrue("training took more than 10s: " + trainTime / 1000 + "s", trainTime < 10000);

    // building the confusion matrix classifies every indexed document
    long evaluationBegin = System.currentTimeMillis();
    ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(leafReader,
        knnClassifier, categoryFieldName, textFieldName);
    assertNotNull(confusionMatrix);
    long evaluationTime = System.currentTimeMillis() - evaluationBegin;
    assertTrue("evaluation took more than 2m: " + evaluationTime / 1000 + "s", evaluationTime < 120000);

    // a single classification should stay under the generator's 5s timeout
    assertTrue(5000 > confusionMatrix.getAvgClassificationTime());
  } finally {
    leafReader.close();
  }
}
}

View File

@ -22,11 +22,12 @@ import org.apache.lucene.analysis.Tokenizer;
import org.apache.lucene.analysis.core.KeywordTokenizer;
import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
import org.apache.lucene.analysis.reverse.ReverseStringFilter;
import org.apache.lucene.classification.utils.ConfusionMatrixGenerator;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.SlowCompositeReaderWrapper;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.util.BytesRef;
import org.junit.Ignore;
import org.junit.Test;
/**
@ -39,9 +40,10 @@ public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase<Bytes
LeafReader leafReader = null;
try {
MockAnalyzer analyzer = new MockAnalyzer(random());
leafReader = populateSampleIndex(analyzer);
checkCorrectClassification(new SimpleNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
checkCorrectClassification(new SimpleNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), POLITICS_INPUT, POLITICS_RESULT);
leafReader = getSampleIndex(analyzer);
SimpleNaiveBayesClassifier classifier = new SimpleNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName);
checkCorrectClassification(classifier, TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
checkCorrectClassification(classifier, POLITICS_INPUT, POLITICS_RESULT);
} finally {
if (leafReader != null) {
leafReader.close();
@ -54,7 +56,7 @@ public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase<Bytes
LeafReader leafReader = null;
try {
MockAnalyzer analyzer = new MockAnalyzer(random());
leafReader = populateSampleIndex(analyzer);
leafReader = getSampleIndex(analyzer);
TermQuery query = new TermQuery(new Term(textFieldName, "it"));
checkCorrectClassification(new SimpleNaiveBayesClassifier(leafReader, analyzer, query, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
} finally {
@ -69,7 +71,7 @@ public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase<Bytes
LeafReader leafReader = null;
try {
Analyzer analyzer = new NGramAnalyzer();
leafReader = populateSampleIndex(analyzer);
leafReader = getSampleIndex(analyzer);
checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
} finally {
if (leafReader != null) {
@ -86,4 +88,32 @@ public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase<Bytes
}
}
/**
 * Rough performance check for {@code SimpleNaiveBayesClassifier} over a
 * random 100-document index. Currently disabled.
 */
@Ignore
@Test
public void testPerformance() throws Exception {
  MockAnalyzer analyzer = new MockAnalyzer(random());
  LeafReader leafReader = getRandomIndex(analyzer, 100);
  try {
    // training happens in the constructor; time it
    long trainingBegin = System.currentTimeMillis();
    SimpleNaiveBayesClassifier classifier = new SimpleNaiveBayesClassifier(leafReader,
        analyzer, null, categoryFieldName, textFieldName);
    long trainTime = System.currentTimeMillis() - trainingBegin;
    assertTrue("training took more than 10s: " + trainTime / 1000 + "s", trainTime < 10000);

    // building the confusion matrix classifies every indexed document
    long evaluationBegin = System.currentTimeMillis();
    ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(leafReader,
        classifier, categoryFieldName, textFieldName);
    assertNotNull(confusionMatrix);
    long evaluationTime = System.currentTimeMillis() - evaluationBegin;
    assertTrue("evaluation took more than 2m: " + evaluationTime / 1000 + "s", evaluationTime < 120000);

    // a single classification should stay under the generator's 5s timeout
    double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
    assertTrue("avg classification time: " + avgClassificationTime, 5000 > avgClassificationTime);
  } finally {
    leafReader.close();
  }
}
}

View File

@ -17,9 +17,13 @@ package org.apache.lucene.classification.utils;
* limitations under the License.
*/
import java.io.IOException;
import java.util.List;
import org.apache.lucene.analysis.MockAnalyzer;
import org.apache.lucene.classification.BooleanPerceptronClassifier;
import org.apache.lucene.classification.CachingNaiveBayesClassifier;
import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.ClassificationTestBase;
import org.apache.lucene.classification.Classifier;
import org.apache.lucene.classification.KNearestNeighborClassifier;
@ -33,16 +37,53 @@ import org.junit.Test;
*/
public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object> {
/**
 * Check that a confusion matrix can be built with a dummy classifier that
 * always returns an empty class with a pseudo-random score.
 */
@Test
public void testGetConfusionMatrix() throws Exception {
  LeafReader reader = null;
  try {
    MockAnalyzer analyzer = new MockAnalyzer(random());
    reader = getSampleIndex(analyzer);
    // dummy classifier: assigns an empty BytesRef class with a sigmoid of a random int
    Classifier<BytesRef> randomClassifier = new Classifier<BytesRef>() {
      @Override
      public ClassificationResult<BytesRef> assignClass(String text) throws IOException {
        return new ClassificationResult<>(new BytesRef(), 1 / (1 + Math.exp(-random().nextInt())));
      }

      @Override
      public List<ClassificationResult<BytesRef>> getClasses(String text) throws IOException {
        return null;
      }

      @Override
      public List<ClassificationResult<BytesRef>> getClasses(String text, int max) throws IOException {
        return null;
      }
    };
    ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
        randomClassifier, categoryFieldName, textFieldName);
    assertNotNull(confusionMatrix);
    assertNotNull(confusionMatrix.getLinearizedMatrix());
    // the sample index contains 7 documents, all of which must be evaluated
    assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
    assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
  } finally {
    if (reader != null) {
      reader.close();
    }
  }
}
@Test
public void testGetConfusionMatrixWithSNB() throws Exception {
LeafReader reader = null;
try {
MockAnalyzer analyzer = new MockAnalyzer(random());
reader = populateSampleIndex(analyzer);
reader = getSampleIndex(analyzer);
Classifier<BytesRef> classifier = new SimpleNaiveBayesClassifier(reader, analyzer, null, categoryFieldName, textFieldName);
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, categoryFieldName, textFieldName);
assertNotNull(confusionMatrix);
assertNotNull(confusionMatrix.getLinearizedMatrix());
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
assertTrue(confusionMatrix.getAvgClassificationTime() > 0d);
} finally {
if (reader != null) {
reader.close();
@ -55,11 +96,13 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
LeafReader reader = null;
try {
MockAnalyzer analyzer = new MockAnalyzer(random());
reader = populateSampleIndex(analyzer);
reader = getSampleIndex(analyzer);
Classifier<BytesRef> classifier = new CachingNaiveBayesClassifier(reader, analyzer, null, categoryFieldName, textFieldName);
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, categoryFieldName, textFieldName);
assertNotNull(confusionMatrix);
assertNotNull(confusionMatrix.getLinearizedMatrix());
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
assertTrue(confusionMatrix.getAvgClassificationTime() > 0d);
} finally {
if (reader != null) {
reader.close();
@ -72,11 +115,13 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
LeafReader reader = null;
try {
MockAnalyzer analyzer = new MockAnalyzer(random());
reader = populateSampleIndex(analyzer);
reader = getSampleIndex(analyzer);
Classifier<BytesRef> classifier = new KNearestNeighborClassifier(reader, null, analyzer, null, 1, 0, 0, categoryFieldName, textFieldName);
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, categoryFieldName, textFieldName);
assertNotNull(confusionMatrix);
assertNotNull(confusionMatrix.getLinearizedMatrix());
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
assertTrue(confusionMatrix.getAvgClassificationTime() > 0d);
} finally {
if (reader != null) {
reader.close();
@ -89,11 +134,13 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
LeafReader reader = null;
try {
MockAnalyzer analyzer = new MockAnalyzer(random());
reader = populateSampleIndex(analyzer);
reader = getSampleIndex(analyzer);
Classifier<Boolean> classifier = new BooleanPerceptronClassifier(reader, analyzer, null, 1, null, booleanFieldName, textFieldName);
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, booleanFieldName, textFieldName);
assertNotNull(confusionMatrix);
assertNotNull(confusionMatrix.getLinearizedMatrix());
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
} finally {
if (reader != null) {
reader.close();