mirror of https://github.com/apache/lucene.git
LUCENE-6479 - improved cm testing, added stats, minor fixes
git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1700914 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
e16e914057
commit
8b2e0d937d
@@ -173,7 +173,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
           // update weights
           Long previousValue = Util.get(fst, term);
           String termString = term.utf8ToString();
-          weights.put(termString, previousValue + modifier * termFreqLocal);
+          weights.put(termString, previousValue == null ? 0 : previousValue + modifier * termFreqLocal);
         }
       }
       if (updateFST) {
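Editor's note: the guard above matters because Util.get(fst, term) returns a boxed Long that is null for terms not yet in the FST, and auto-unboxing null inside previousValue + modifier * termFreqLocal throws a NullPointerException. A standalone sketch (toy map and names, not the Lucene code) of the failure and the guard:

import java.util.HashMap;
import java.util.Map;

public class NullGuardDemo {
  public static void main(String[] args) {
    Map<String, Long> weights = new HashMap<>();
    long delta = 3;
    Long previous = weights.get("lucene");            // null: key absent
    // long w = previous + delta;                     // would throw NullPointerException
    long w = previous == null ? 0 : previous + delta; // guarded, as in the fix
    weights.put("lucene", w);
    System.out.println(weights);                      // prints {lucene=0}
  }
}

Note that the guarded branch stores 0 rather than the increment itself, mirroring the committed expression exactly.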
@@ -214,6 +214,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
         }
       }
+      tokenStream.end();
       tokenStream.close();
     }
 
     double score = 1 - Math.exp(-1 * Math.abs(threshold - output.doubleValue()) / threshold);
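Editor's note: this hunk and the matching one in SimpleNaiveBayesClassifier below add the end() call that Lucene's TokenStream contract places between the incrementToken() loop and close(); end() lets the stream publish end-of-stream state such as the final offset. A minimal consumption sketch, assuming the stock StandardAnalyzer from lucene-analyzers-common is on the classpath:

import java.io.IOException;
import java.io.StringReader;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;

public class TokenStreamContract {
  public static void main(String[] args) throws IOException {
    Analyzer analyzer = new StandardAnalyzer();
    try (TokenStream ts = analyzer.tokenStream("text", new StringReader("hello lucene"))) {
      CharTermAttribute term = ts.addAttribute(CharTermAttribute.class);
      ts.reset();                 // required before the first incrementToken()
      while (ts.incrementToken()) {
        System.out.println(term.toString());
      }
      ts.end();                   // the call this commit adds before close()
    }                             // try-with-resources closes the stream
    analyzer.close();
  }
}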
@@ -80,7 +80,7 @@ public class CachingNaiveBayesClassifier extends SimpleNaiveBayesClassifier {
   }
 
 
-  private List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException {
+  protected List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException {
     String[] tokenizedDoc = tokenizeDoc(inputDocument);
 
     List<ClassificationResult<BytesRef>> dataList = calculateLogLikelihood(tokenizedDoc);
@@ -200,7 +200,7 @@ public class CachingNaiveBayesClassifier extends SimpleNaiveBayesClassifier {
       }
     }
     if (insertPoint != null) {
-      // threadsafe and concurent write
+      // threadsafe and concurrent write
       termCClassHitCache.put(word, searched);
     }
   }
@@ -134,7 +134,13 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
     return doclist.subList(0, max);
   }
 
-  private List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException {
+  /**
+   * Calculate probabilities for all classes for a given input text
+   * @param inputDocument the input text as a {@code String}
+   * @return a {@code List} of {@code ClassificationResult}, one for each existing class
+   * @throws IOException if assigning probabilities fails
+   */
+  protected List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException {
     List<ClassificationResult<BytesRef>> dataList = new ArrayList<>();
 
     Terms terms = MultiFields.getTerms(leafReader, classFieldName);
@@ -143,8 +149,10 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
     String[] tokenizedDoc = tokenizeDoc(inputDocument);
     int docsWithClassSize = countDocsWithClass();
     while ((next = termsEnum.next()) != null) {
-      double clVal = calculateLogPrior(next, docsWithClassSize) + calculateLogLikelihood(tokenizedDoc, next, docsWithClassSize);
-      dataList.add(new ClassificationResult<>(BytesRef.deepCopyOf(next), clVal));
+      if (next.length > 0) {
+        double clVal = calculateLogPrior(next, docsWithClassSize) + calculateLogLikelihood(tokenizedDoc, next, docsWithClassSize);
+        dataList.add(new ClassificationResult<>(BytesRef.deepCopyOf(next), clVal));
+      }
     }
 
     // normalization; the values transforms to a 0-1 range
@@ -212,6 +220,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
         result.add(charTermAttribute.toString());
       }
+      tokenStream.end();
       tokenStream.close();
     }
   }
   return result.toArray(new String[result.size()]);
@@ -21,11 +21,21 @@ import java.io.IOException;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
 
 import org.apache.lucene.classification.ClassificationResult;
 import org.apache.lucene.classification.Classifier;
 import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.StoredDocument;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.WildcardQuery;
 import org.apache.lucene.util.BytesRef;
 
 /**
@@ -49,37 +59,67 @@ public class ConfusionMatrixGenerator {
    * @return a {@link org.apache.lucene.classification.utils.ConfusionMatrixGenerator.ConfusionMatrix}
    * @throws IOException if problems occurr while reading the index or using the classifier
    */
-  public static <T> ConfusionMatrix getConfusionMatrix(LeafReader reader, Classifier<T> classifier, String classFieldName, String textFieldName) throws IOException {
-    Map<String, Map<String, Long>> counts = new HashMap<>();
-    for (int i = 0; i < reader.maxDoc(); i++) {
-      StoredDocument doc = reader.document(i);
-      String correctAnswer = doc.get(classFieldName);
-
-      if (correctAnswer != null && correctAnswer.length() > 0) {
-        ClassificationResult<T> result = classifier.assignClass(doc.get(textFieldName));
-        T assignedClass = result.getAssignedClass();
-        String classified = assignedClass instanceof BytesRef ? ((BytesRef) assignedClass).utf8ToString() : assignedClass.toString();
-
-        Map<String, Long> stringLongMap = counts.get(correctAnswer);
-        if (stringLongMap != null) {
-          Long aLong = stringLongMap.get(classified);
-          if (aLong != null) {
-            stringLongMap.put(classified, aLong + 1);
-          } else {
-            stringLongMap.put(classified, 1l);
-          }
-        } else {
-          stringLongMap = new HashMap<>();
-          stringLongMap.put(classified, 1l);
-          counts.put(correctAnswer, stringLongMap);
-        }
-
-      }
-    }
-    return new ConfusionMatrix(counts);
+  public static <T> ConfusionMatrix getConfusionMatrix(LeafReader reader, Classifier<T> classifier, String classFieldName,
+                                                       String textFieldName) throws IOException {
+
+    ExecutorService executorService = Executors.newFixedThreadPool(1);
+    try {
+      Map<String, Map<String, Long>> counts = new HashMap<>();
+      IndexSearcher indexSearcher = new IndexSearcher(reader);
+      TopDocs topDocs = indexSearcher.search(new WildcardQuery(new Term(classFieldName, "*")), Integer.MAX_VALUE);
+      double time = 0d;
+
+      for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
+        StoredDocument doc = reader.document(scoreDoc.doc);
+        String correctAnswer = doc.get(classFieldName);
+
+        if (correctAnswer != null && correctAnswer.length() > 0) {
+          ClassificationResult<T> result;
+          String text = doc.get(textFieldName);
+          if (text != null) {
+            try {
+              // fail if classification takes more than 5s
+              long start = System.currentTimeMillis();
+              result = executorService.submit(() -> classifier.assignClass(text)).get(5, TimeUnit.SECONDS);
+              long end = System.currentTimeMillis();
+              time += end - start;
+
+              if (result != null) {
+                T assignedClass = result.getAssignedClass();
+                if (assignedClass != null) {
+                  String classified = assignedClass instanceof BytesRef ? ((BytesRef) assignedClass).utf8ToString() : assignedClass.toString();
+
+                  Map<String, Long> stringLongMap = counts.get(correctAnswer);
+                  if (stringLongMap != null) {
+                    Long aLong = stringLongMap.get(classified);
+                    if (aLong != null) {
+                      stringLongMap.put(classified, aLong + 1);
+                    } else {
+                      stringLongMap.put(classified, 1l);
+                    }
+                  } else {
+                    stringLongMap = new HashMap<>();
+                    stringLongMap.put(classified, 1l);
+                    counts.put(correctAnswer, stringLongMap);
+                  }
+                }
+              }
+            } catch (TimeoutException timeoutException) {
+              // add timeout
+              time += 5000;
+            } catch (ExecutionException | InterruptedException executionException) {
+              throw new RuntimeException(executionException);
+            }
+          }
+        }
+      }
+      return new ConfusionMatrix(counts, time / topDocs.totalHits, topDocs.totalHits);
+    } finally {
+      executorService.shutdown();
+    }
   }
 
   /**
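Editor's note: the heart of this rewrite is the timeout pattern: each assignClass call now runs on a single-thread executor and is bounded by Future.get(5, TimeUnit.SECONDS), so one pathological document can no longer hang the whole matrix computation; a timeout is charged as 5000 ms and the loop moves on. A self-contained sketch of just that pattern (the names and the 100 ms workload are illustrative):

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

public class TimeoutPattern {
  public static void main(String[] args) throws Exception {
    ExecutorService executorService = Executors.newFixedThreadPool(1);
    try {
      Future<String> future = executorService.submit(() -> {
        Thread.sleep(100);                     // stand-in for classifier.assignClass(text)
        return "politics";
      });
      try {
        System.out.println(future.get(5, TimeUnit.SECONDS)); // bounded wait
      } catch (TimeoutException e) {
        System.out.println("timed out");       // the generator adds 5000 ms and continues
      }
    } finally {
      executorService.shutdown();              // mirrors the generator's finally block
    }
  }
}

Since the commit already relies on Java 8 lambdas, the nested get/put counting above could equally be written as counts.computeIfAbsent(correctAnswer, k -> new HashMap<>()).merge(classified, 1L, Long::sum) with no change in behavior.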
@@ -88,9 +128,13 @@ public class ConfusionMatrixGenerator {
   public static class ConfusionMatrix {
 
     private final Map<String, Map<String, Long>> linearizedMatrix;
+    private final double avgClassificationTime;
+    private final int numberOfEvaluatedDocs;
 
-    private ConfusionMatrix(Map<String, Map<String, Long>> linearizedMatrix) {
+    private ConfusionMatrix(Map<String, Map<String, Long>> linearizedMatrix, double avgClassificationTime, int numberOfEvaluatedDocs) {
       this.linearizedMatrix = linearizedMatrix;
+      this.avgClassificationTime = avgClassificationTime;
+      this.numberOfEvaluatedDocs = numberOfEvaluatedDocs;
     }
 
     /**
@@ -104,8 +148,26 @@ public class ConfusionMatrixGenerator {
     @Override
     public String toString() {
       return "ConfusionMatrix{" +
-          "linearizedMatrix=" + linearizedMatrix +
-          '}';
+          "linearizedMatrix=" + linearizedMatrix +
+          ", avgClassificationTime=" + avgClassificationTime +
+          ", numberOfEvaluatedDocs=" + numberOfEvaluatedDocs +
+          '}';
     }
+
+    /**
+     * get the average classification time in milliseconds
+     * @return the avg classification time
+     */
+    public double getAvgClassificationTime() {
+      return avgClassificationTime;
+    }
+
+    /**
+     * get the no. of documents evaluated while generating this confusion matrix
+     * @return the no. of documents evaluated
+     */
+    public int getNumberOfEvaluatedDocs() {
+      return numberOfEvaluatedDocs;
+    }
   }
 }
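Editor's note: with the two new accessors in place, callers can read the added statistics directly. A toy end-to-end sketch, assuming the 5.x trunk APIs this diff itself uses (SlowCompositeReaderWrapper, the SimpleNaiveBayesClassifier constructor from the tests) plus RAMDirectory and StandardAnalyzer; the two-document index and field names are illustrative:

import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.classification.SimpleNaiveBayesClassifier;
import org.apache.lucene.classification.utils.ConfusionMatrixGenerator;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.SlowCompositeReaderWrapper;
import org.apache.lucene.store.RAMDirectory;

public class ConfusionMatrixDemo {
  public static void main(String[] args) throws Exception {
    RAMDirectory dir = new RAMDirectory();
    StandardAnalyzer analyzer = new StandardAnalyzer();
    try (IndexWriter writer = new IndexWriter(dir, new IndexWriterConfig(analyzer))) {
      String[][] data = {{"politics", "mitt romney campaign question"},
                         {"technology", "google apple amazon user data"}};
      for (String[] d : data) {
        Document doc = new Document();
        doc.add(new Field("category", d[0], TextField.TYPE_STORED));
        doc.add(new Field("text", d[1], TextField.TYPE_STORED));
        writer.addDocument(doc);
      }
    }
    LeafReader reader = SlowCompositeReaderWrapper.wrap(DirectoryReader.open(dir));
    SimpleNaiveBayesClassifier classifier =
        new SimpleNaiveBayesClassifier(reader, analyzer, null, "category", "text");
    ConfusionMatrixGenerator.ConfusionMatrix cm =
        ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, "category", "text");
    System.out.println("matrix: " + cm.getLinearizedMatrix());
    System.out.println("avg time (ms): " + cm.getAvgClassificationTime());  // new in this commit
    System.out.println("evaluated docs: " + cm.getNumberOfEvaluatedDocs()); // new in this commit
    reader.close();
    dir.close();
  }
}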
@@ -17,6 +17,7 @@
 package org.apache.lucene.classification;
 
 import org.apache.lucene.analysis.MockAnalyzer;
+import org.apache.lucene.classification.utils.ConfusionMatrixGenerator;
 import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.Term;
 import org.apache.lucene.search.TermQuery;
@@ -32,7 +33,7 @@ public class BooleanPerceptronClassifierTest extends ClassificationTestBase<Boolean> {
     LeafReader leafReader = null;
     try {
       MockAnalyzer analyzer = new MockAnalyzer(random());
-      leafReader = populateSampleIndex(analyzer);
+      leafReader = getSampleIndex(analyzer);
       checkCorrectClassification(new BooleanPerceptronClassifier(leafReader, analyzer, null, 1, null, booleanFieldName, textFieldName), TECHNOLOGY_INPUT, false);
     } finally {
       if (leafReader != null) {
@@ -46,7 +47,7 @@ public class BooleanPerceptronClassifierTest extends ClassificationTestBase<Boolean> {
     LeafReader leafReader = null;
     try {
       MockAnalyzer analyzer = new MockAnalyzer(random());
-      leafReader = populateSampleIndex(analyzer);
+      leafReader = getSampleIndex(analyzer);
       BooleanPerceptronClassifier classifier = new BooleanPerceptronClassifier(leafReader, analyzer, null, 1, 50d, booleanFieldName, textFieldName);
       checkCorrectClassification(classifier, TECHNOLOGY_INPUT, false);
       checkCorrectClassification(classifier, POLITICS_INPUT, true);
@@ -63,7 +64,7 @@ public class BooleanPerceptronClassifierTest extends ClassificationTestBase<Boolean> {
     LeafReader leafReader = null;
     try {
       MockAnalyzer analyzer = new MockAnalyzer(random());
-      leafReader = populateSampleIndex(analyzer);
+      leafReader = getSampleIndex(analyzer);
       checkCorrectClassification(new BooleanPerceptronClassifier(leafReader, analyzer, query, 1, null, booleanFieldName, textFieldName), TECHNOLOGY_INPUT, false);
     } finally {
       if (leafReader != null) {
@@ -72,4 +73,29 @@ public class BooleanPerceptronClassifierTest extends ClassificationTestBase<Boolean> {
     }
   }
 
+  @Test
+  public void testPerformance() throws Exception {
+    MockAnalyzer analyzer = new MockAnalyzer(random());
+    LeafReader leafReader = getRandomIndex(analyzer, 100);
+    try {
+      long trainStart = System.currentTimeMillis();
+      BooleanPerceptronClassifier classifier = new BooleanPerceptronClassifier(leafReader, analyzer, null, 1, 0d, booleanFieldName, textFieldName);
+      long trainEnd = System.currentTimeMillis();
+      long trainTime = trainEnd - trainStart;
+      assertTrue("training took more than 10s: " + trainTime / 1000 + "s", trainTime < 10000);
+
+      long evaluationStart = System.currentTimeMillis();
+      ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(leafReader,
+          classifier, categoryFieldName, textFieldName);
+      assertNotNull(confusionMatrix);
+      long evaluationEnd = System.currentTimeMillis();
+      long evaluationTime = evaluationEnd - evaluationStart;
+      assertTrue("evaluation took more than 1m: " + evaluationTime / 1000 + "s", evaluationTime < 60000);
+      double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
+      assertTrue(5000 > avgClassificationTime);
+    } finally {
+      leafReader.close();
+    }
+  }
+
 }
@@ -23,8 +23,8 @@ import org.apache.lucene.analysis.Tokenizer;
 import org.apache.lucene.analysis.core.KeywordTokenizer;
 import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
 import org.apache.lucene.analysis.reverse.ReverseStringFilter;
+import org.apache.lucene.classification.utils.ConfusionMatrixGenerator;
 import org.apache.lucene.index.LeafReader;
-import org.apache.lucene.index.SlowCompositeReaderWrapper;
 import org.apache.lucene.index.Term;
 import org.apache.lucene.search.TermQuery;
 import org.apache.lucene.util.BytesRef;
@@ -40,7 +40,7 @@ public class CachingNaiveBayesClassifierTest extends ClassificationTestBase<BytesRef> {
     LeafReader leafReader = null;
     try {
       MockAnalyzer analyzer = new MockAnalyzer(random());
-      leafReader = populateSampleIndex(analyzer);
+      leafReader = getSampleIndex(analyzer);
       checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
       checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), POLITICS_INPUT, POLITICS_RESULT);
     } finally {
@@ -55,7 +55,7 @@ public class CachingNaiveBayesClassifierTest extends ClassificationTestBase<BytesRef> {
     LeafReader leafReader = null;
     try {
       MockAnalyzer analyzer = new MockAnalyzer(random());
-      leafReader = populateSampleIndex(analyzer);
+      leafReader = getSampleIndex(analyzer);
       TermQuery query = new TermQuery(new Term(textFieldName, "it"));
       checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, query, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
     } finally {
@@ -70,7 +70,7 @@ public class CachingNaiveBayesClassifierTest extends ClassificationTestBase<BytesRef> {
     LeafReader leafReader = null;
     try {
       NGramAnalyzer analyzer = new NGramAnalyzer();
-      leafReader = populateSampleIndex(analyzer);
+      leafReader = getSampleIndex(analyzer);
       checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
     } finally {
       if (leafReader != null) {
@@ -87,4 +87,31 @@ public class CachingNaiveBayesClassifierTest extends ClassificationTestBase<BytesRef> {
     }
   }
 
+  @Test
+  public void testPerformance() throws Exception {
+    MockAnalyzer analyzer = new MockAnalyzer(random());
+    LeafReader leafReader = getRandomIndex(analyzer, 100);
+    try {
+      long trainStart = System.currentTimeMillis();
+      CachingNaiveBayesClassifier simpleNaiveBayesClassifier = new CachingNaiveBayesClassifier(leafReader,
+          analyzer, null, categoryFieldName, textFieldName);
+      long trainEnd = System.currentTimeMillis();
+      long trainTime = trainEnd - trainStart;
+      assertTrue("training took more than 10s: " + trainTime / 1000 + "s", trainTime < 10000);
+
+      long evaluationStart = System.currentTimeMillis();
+      ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(leafReader,
+          simpleNaiveBayesClassifier, categoryFieldName, textFieldName);
+      assertNotNull(confusionMatrix);
+      long evaluationEnd = System.currentTimeMillis();
+      long evaluationTime = evaluationEnd - evaluationStart;
+      assertTrue("evaluation took more than 1m: " + evaluationTime / 1000 + "s", evaluationTime < 60000);
+      double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
+      assertTrue(5000 > avgClassificationTime);
+    } finally {
+      leafReader.close();
+    }
+
+  }
+
 }
@@ -40,21 +40,21 @@ import org.junit.Before;
  * Base class for testing {@link Classifier}s
  */
 public abstract class ClassificationTestBase<T> extends LuceneTestCase {
-  public final static String POLITICS_INPUT = "Here are some interesting questions and answers about Mitt Romney.. " +
+  protected static final String POLITICS_INPUT = "Here are some interesting questions and answers about Mitt Romney.. " +
       "If you don't know the answer to the question about Mitt Romney, then simply click on the answer below the question section.";
-  public static final BytesRef POLITICS_RESULT = new BytesRef("politics");
+  protected static final BytesRef POLITICS_RESULT = new BytesRef("politics");
 
-  public static final String TECHNOLOGY_INPUT = "Much is made of what the likes of Facebook, Google and Apple know about users." +
+  protected static final String TECHNOLOGY_INPUT = "Much is made of what the likes of Facebook, Google and Apple know about users." +
       " Truth is, Amazon may know more.";
 
-  public static final String STRONG_TECHNOLOGY_INPUT = "Much is made of what the likes of Facebook, Google and Apple know about users." +
+  protected static final String STRONG_TECHNOLOGY_INPUT = "Much is made of what the likes of Facebook, Google and Apple know about users." +
       " Truth is, Amazon may know more. This technology observation is extracted from the internet.";
 
-  public static final String SUPER_STRONG_TECHNOLOGY_INPUT = "More than 400 million people trust Google with their e-mail, and 50 million store files" +
+  protected static final String SUPER_STRONG_TECHNOLOGY_INPUT = "More than 400 million people trust Google with their e-mail, and 50 million store files" +
       " in the cloud using the Dropbox service. People manage their bank accounts, pay bills, trade stocks and " +
       "generally transfer or store huge volumes of personal data online. traveling seeks raises some questions Republican presidential. ";
 
-  public static final BytesRef TECHNOLOGY_RESULT = new BytesRef("technology");
+  protected static final BytesRef TECHNOLOGY_RESULT = new BytesRef("technology");
 
   protected RandomIndexWriter indexWriter;
   private Directory dir;
@@ -101,7 +101,7 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
   }
 
   protected void checkOnlineClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName, Query query) throws Exception {
-    populateSampleIndex(analyzer);
+    getSampleIndex(analyzer);
 
     ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
     assertNotNull(classificationResult.getAssignedClass());
@@ -115,7 +115,7 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
 
   }
 
-  protected LeafReader populateSampleIndex(Analyzer analyzer) throws IOException {
+  protected LeafReader getSampleIndex(Analyzer analyzer) throws IOException {
     indexWriter.close();
     indexWriter = new RandomIndexWriter(random(), dir, newIndexWriterConfig(analyzer).setOpenMode(IndexWriterConfig.OpenMode.CREATE));
     indexWriter.commit();
@@ -193,34 +193,27 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
     return SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
   }
 
-  protected void checkPerformance(Classifier<T> classifier, Analyzer analyzer, String classFieldName) throws Exception {
-    long trainStart = System.currentTimeMillis();
-    populatePerformanceIndex(analyzer);
-    long trainEnd = System.currentTimeMillis();
-    long trainTime = trainEnd - trainStart;
-    assertTrue("training took more than 2 mins : " + trainTime / 1000 + "s", trainTime < 120000);
-  }
-
-  private void populatePerformanceIndex(Analyzer analyzer) throws IOException {
+  protected LeafReader getRandomIndex(Analyzer analyzer, int size) throws IOException {
     indexWriter.close();
     indexWriter = new RandomIndexWriter(random(), dir, newIndexWriterConfig(analyzer).setOpenMode(IndexWriterConfig.OpenMode.CREATE));
     indexWriter.deleteAll();
     indexWriter.commit();
 
     FieldType ft = new FieldType(TextField.TYPE_STORED);
     ft.setStoreTermVectors(true);
     ft.setStoreTermVectorOffsets(true);
     ft.setStoreTermVectorPositions(true);
-    int docs = 1000;
     Random random = random();
-    for (int i = 0; i < docs; i++) {
+    for (int i = 0; i < size; i++) {
       boolean b = random.nextBoolean();
       Document doc = new Document();
       doc.add(new Field(textFieldName, createRandomString(random), ft));
-      doc.add(new Field(categoryFieldName, b ? "technology" : "politics", ft));
+      doc.add(new Field(categoryFieldName, String.valueOf(random.nextInt(1000)), ft));
       doc.add(new Field(booleanFieldName, String.valueOf(b), ft));
       indexWriter.addDocument(doc);
     }
     indexWriter.commit();
+    return SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
   }
 
   private String createRandomString(Random random) {
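Editor's note: getRandomIndex fills each document's text via createRandomString(random), whose body lies outside this hunk. A purely hypothetical stand-in (not the committed code) that yields word-like random tokens:

import java.util.Random;

public class RandomText {
  // Hypothetical helper: 20 hex "words" per document, space separated.
  static String createRandomString(Random random) {
    StringBuilder sb = new StringBuilder();
    for (int i = 0; i < 20; i++) {
      sb.append(Integer.toHexString(random.nextInt())).append(' ');
    }
    return sb.toString().trim();
  }

  public static void main(String[] args) {
    System.out.println(createRandomString(new Random(42)));
  }
}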
@@ -21,6 +21,7 @@ import java.util.List;
 import org.apache.lucene.analysis.Analyzer;
 import org.apache.lucene.analysis.MockAnalyzer;
 import org.apache.lucene.analysis.en.EnglishAnalyzer;
+import org.apache.lucene.classification.utils.ConfusionMatrixGenerator;
 import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.Term;
 import org.apache.lucene.search.TermQuery;
@@ -38,7 +39,7 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase<BytesRef> {
     LeafReader leafReader = null;
     try {
       MockAnalyzer analyzer = new MockAnalyzer(random());
-      leafReader = populateSampleIndex(analyzer);
+      leafReader = getSampleIndex(analyzer);
       checkCorrectClassification(new KNearestNeighborClassifier(leafReader, null, analyzer, null, 1, 0, 0, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
       checkCorrectClassification(new KNearestNeighborClassifier(leafReader, new LMDirichletSimilarity(), analyzer, null, 1, 0, 0, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
       ClassificationResult<BytesRef> resultDS = checkCorrectClassification(new KNearestNeighborClassifier(leafReader, null, analyzer, null, 3, 2, 1, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
@@ -63,7 +64,7 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase<BytesRef> {
     LeafReader leafReader = null;
     try {
       Analyzer analyzer = new EnglishAnalyzer();
-      leafReader = populateSampleIndex(analyzer);
+      leafReader = getSampleIndex(analyzer);
       KNearestNeighborClassifier knnClassifier = new KNearestNeighborClassifier(leafReader, null, analyzer, null, 6, 1, 1, categoryFieldName, textFieldName);
       List<ClassificationResult<BytesRef>> classes = knnClassifier.getClasses(STRONG_TECHNOLOGY_INPUT);
       assertTrue(classes.get(0).getScore() > classes.get(1).getScore());
@@ -88,7 +89,7 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase<BytesRef> {
     LeafReader leafReader = null;
     try {
       Analyzer analyzer = new EnglishAnalyzer();
-      leafReader = populateSampleIndex(analyzer);
+      leafReader = getSampleIndex(analyzer);
       KNearestNeighborClassifier knnClassifier = new KNearestNeighborClassifier(leafReader, null, analyzer, null, 3, 1, 1, categoryFieldName, textFieldName);
       List<ClassificationResult<BytesRef>> classes = knnClassifier.getClasses(SUPER_STRONG_TECHNOLOGY_INPUT);
       assertTrue(classes.get(0).getScore() > classes.get(1).getScore());
@@ -105,7 +106,7 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase<BytesRef> {
     LeafReader leafReader = null;
     try {
       MockAnalyzer analyzer = new MockAnalyzer(random());
-      leafReader = populateSampleIndex(analyzer);
+      leafReader = getSampleIndex(analyzer);
       TermQuery query = new TermQuery(new Term(textFieldName, "it"));
       checkCorrectClassification(new KNearestNeighborClassifier(leafReader, null, analyzer, query, 1, 0, 0, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
     } finally {
@@ -115,4 +116,30 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase<BytesRef> {
     }
   }
 
+  @Test
+  public void testPerformance() throws Exception {
+    MockAnalyzer analyzer = new MockAnalyzer(random());
+    LeafReader leafReader = getRandomIndex(analyzer, 100);
+    try {
+      long trainStart = System.currentTimeMillis();
+      KNearestNeighborClassifier kNearestNeighborClassifier = new KNearestNeighborClassifier(leafReader, null,
+          analyzer, null, 1, 2, 2, categoryFieldName, textFieldName);
+      long trainEnd = System.currentTimeMillis();
+      long trainTime = trainEnd - trainStart;
+      assertTrue("training took more than 10s: " + trainTime / 1000 + "s", trainTime < 10000);
+
+      long evaluationStart = System.currentTimeMillis();
+      ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(leafReader,
+          kNearestNeighborClassifier, categoryFieldName, textFieldName);
+      assertNotNull(confusionMatrix);
+      long evaluationEnd = System.currentTimeMillis();
+      long evaluationTime = evaluationEnd - evaluationStart;
+      assertTrue("evaluation took more than 2m: " + evaluationTime / 1000 + "s", evaluationTime < 120000);
+      double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
+      assertTrue(5000 > avgClassificationTime);
+    } finally {
+      leafReader.close();
+    }
+  }
+
 }
@@ -22,11 +22,12 @@ import org.apache.lucene.analysis.Tokenizer;
 import org.apache.lucene.analysis.core.KeywordTokenizer;
 import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
 import org.apache.lucene.analysis.reverse.ReverseStringFilter;
+import org.apache.lucene.classification.utils.ConfusionMatrixGenerator;
 import org.apache.lucene.index.LeafReader;
-import org.apache.lucene.index.SlowCompositeReaderWrapper;
 import org.apache.lucene.index.Term;
 import org.apache.lucene.search.TermQuery;
 import org.apache.lucene.util.BytesRef;
+import org.junit.Ignore;
 import org.junit.Test;
 
 /**
@@ -39,9 +40,10 @@ public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase<BytesRef> {
     LeafReader leafReader = null;
     try {
       MockAnalyzer analyzer = new MockAnalyzer(random());
-      leafReader = populateSampleIndex(analyzer);
-      checkCorrectClassification(new SimpleNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
-      checkCorrectClassification(new SimpleNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), POLITICS_INPUT, POLITICS_RESULT);
+      leafReader = getSampleIndex(analyzer);
+      SimpleNaiveBayesClassifier classifier = new SimpleNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName);
+      checkCorrectClassification(classifier, TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
+      checkCorrectClassification(classifier, POLITICS_INPUT, POLITICS_RESULT);
     } finally {
       if (leafReader != null) {
         leafReader.close();
@@ -54,7 +56,7 @@ public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase<BytesRef> {
     LeafReader leafReader = null;
     try {
       MockAnalyzer analyzer = new MockAnalyzer(random());
-      leafReader = populateSampleIndex(analyzer);
+      leafReader = getSampleIndex(analyzer);
       TermQuery query = new TermQuery(new Term(textFieldName, "it"));
       checkCorrectClassification(new SimpleNaiveBayesClassifier(leafReader, analyzer, query, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
     } finally {
@@ -69,7 +71,7 @@ public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase<BytesRef> {
     LeafReader leafReader = null;
     try {
       Analyzer analyzer = new NGramAnalyzer();
-      leafReader = populateSampleIndex(analyzer);
+      leafReader = getSampleIndex(analyzer);
       checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
     } finally {
       if (leafReader != null) {
@@ -86,4 +88,32 @@ public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase<BytesRef> {
     }
   }
 
+  @Ignore
+  @Test
+  public void testPerformance() throws Exception {
+    MockAnalyzer analyzer = new MockAnalyzer(random());
+    LeafReader leafReader = getRandomIndex(analyzer, 100);
+    try {
+      long trainStart = System.currentTimeMillis();
+      SimpleNaiveBayesClassifier simpleNaiveBayesClassifier = new SimpleNaiveBayesClassifier(leafReader,
+          analyzer, null, categoryFieldName, textFieldName);
+      long trainEnd = System.currentTimeMillis();
+      long trainTime = trainEnd - trainStart;
+      assertTrue("training took more than 10s: " + trainTime / 1000 + "s", trainTime < 10000);
+
+      long evaluationStart = System.currentTimeMillis();
+      ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(leafReader,
+          simpleNaiveBayesClassifier, categoryFieldName, textFieldName);
+      assertNotNull(confusionMatrix);
+      long evaluationEnd = System.currentTimeMillis();
+      long evaluationTime = evaluationEnd - evaluationStart;
+      assertTrue("evaluation took more than 2m: " + evaluationTime / 1000 + "s", evaluationTime < 120000);
+      double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
+      assertTrue("avg classification time: " + avgClassificationTime, 5000 > avgClassificationTime);
+    } finally {
+      leafReader.close();
+    }
+
+  }
+
 }
@@ -17,9 +17,13 @@ package org.apache.lucene.classification.utils;
  * limitations under the License.
  */
 
+import java.io.IOException;
+import java.util.List;
 
 import org.apache.lucene.analysis.MockAnalyzer;
 import org.apache.lucene.classification.BooleanPerceptronClassifier;
 import org.apache.lucene.classification.CachingNaiveBayesClassifier;
+import org.apache.lucene.classification.ClassificationResult;
 import org.apache.lucene.classification.ClassificationTestBase;
+import org.apache.lucene.classification.Classifier;
 import org.apache.lucene.classification.KNearestNeighborClassifier;
@@ -33,16 +37,53 @@ import org.junit.Test;
  */
 public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object> {
 
+  @Test
+  public void testGetConfusionMatrix() throws Exception {
+    LeafReader reader = null;
+    try {
+      MockAnalyzer analyzer = new MockAnalyzer(random());
+      reader = getSampleIndex(analyzer);
+      Classifier<BytesRef> classifier = new Classifier<BytesRef>() {
+        @Override
+        public ClassificationResult<BytesRef> assignClass(String text) throws IOException {
+          return new ClassificationResult<>(new BytesRef(), 1 / (1 + Math.exp(-random().nextInt())));
+        }
+
+        @Override
+        public List<ClassificationResult<BytesRef>> getClasses(String text) throws IOException {
+          return null;
+        }
+
+        @Override
+        public List<ClassificationResult<BytesRef>> getClasses(String text, int max) throws IOException {
+          return null;
+        }
+      };
+      ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, categoryFieldName, textFieldName);
+      assertNotNull(confusionMatrix);
+      assertNotNull(confusionMatrix.getLinearizedMatrix());
+      assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
+      double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
+      assertTrue(avgClassificationTime >= 0d);
+    } finally {
+      if (reader != null) {
+        reader.close();
+      }
+    }
+  }
+
   @Test
   public void testGetConfusionMatrixWithSNB() throws Exception {
     LeafReader reader = null;
     try {
       MockAnalyzer analyzer = new MockAnalyzer(random());
-      reader = populateSampleIndex(analyzer);
+      reader = getSampleIndex(analyzer);
       Classifier<BytesRef> classifier = new SimpleNaiveBayesClassifier(reader, analyzer, null, categoryFieldName, textFieldName);
       ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, categoryFieldName, textFieldName);
       assertNotNull(confusionMatrix);
       assertNotNull(confusionMatrix.getLinearizedMatrix());
+      assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
+      assertTrue(confusionMatrix.getAvgClassificationTime() > 0d);
     } finally {
       if (reader != null) {
         reader.close();
@@ -55,11 +96,13 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object> {
     LeafReader reader = null;
     try {
       MockAnalyzer analyzer = new MockAnalyzer(random());
-      reader = populateSampleIndex(analyzer);
+      reader = getSampleIndex(analyzer);
       Classifier<BytesRef> classifier = new CachingNaiveBayesClassifier(reader, analyzer, null, categoryFieldName, textFieldName);
       ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, categoryFieldName, textFieldName);
       assertNotNull(confusionMatrix);
       assertNotNull(confusionMatrix.getLinearizedMatrix());
+      assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
+      assertTrue(confusionMatrix.getAvgClassificationTime() > 0d);
     } finally {
       if (reader != null) {
         reader.close();
@@ -72,11 +115,13 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object> {
     LeafReader reader = null;
     try {
       MockAnalyzer analyzer = new MockAnalyzer(random());
-      reader = populateSampleIndex(analyzer);
+      reader = getSampleIndex(analyzer);
       Classifier<BytesRef> classifier = new KNearestNeighborClassifier(reader, null, analyzer, null, 1, 0, 0, categoryFieldName, textFieldName);
       ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, categoryFieldName, textFieldName);
       assertNotNull(confusionMatrix);
       assertNotNull(confusionMatrix.getLinearizedMatrix());
+      assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
+      assertTrue(confusionMatrix.getAvgClassificationTime() > 0d);
     } finally {
       if (reader != null) {
         reader.close();
@@ -89,11 +134,13 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object> {
     LeafReader reader = null;
     try {
       MockAnalyzer analyzer = new MockAnalyzer(random());
-      reader = populateSampleIndex(analyzer);
+      reader = getSampleIndex(analyzer);
       Classifier<Boolean> classifier = new BooleanPerceptronClassifier(reader, analyzer, null, 1, null, booleanFieldName, textFieldName);
       ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, booleanFieldName, textFieldName);
       assertNotNull(confusionMatrix);
       assertNotNull(confusionMatrix.getLinearizedMatrix());
+      assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
+      assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
     } finally {
       if (reader != null) {
         reader.close();