LUCENE-6479 - added a raw accuracy calculation to confusion matrix, minor adjustments to splitter

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1710197 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Tommaso Teofili 2015-10-23 12:21:50 +00:00
parent 34c87c2c4a
commit 308d0101f3
8 changed files with 80 additions and 9 deletions

View File

@ -19,6 +19,7 @@ package org.apache.lucene.classification;
import java.io.IOException;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
@ -233,4 +234,14 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
}
return returnList;
}
@Override
public String toString() {
return "KNearestNeighborClassifier{" +
", textFieldNames=" + Arrays.toString(textFieldNames) +
", classFieldName='" + classFieldName + '\'' +
", k=" + k +
", query=" + query +
'}';
}
}

View File

@ -18,6 +18,7 @@ package org.apache.lucene.classification.utils;
*/
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
@ -34,6 +35,7 @@ import org.apache.lucene.index.StoredDocument;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TermRangeQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.WildcardQuery;
import org.apache.lucene.util.BytesRef;
@ -69,14 +71,15 @@ public class ConfusionMatrixGenerator {
Map<String, Map<String, Long>> counts = new HashMap<>();
IndexSearcher indexSearcher = new IndexSearcher(reader);
TopDocs topDocs = indexSearcher.search(new WildcardQuery(new Term(classFieldName, "*")), Integer.MAX_VALUE);
TopDocs topDocs = indexSearcher.search(new TermRangeQuery(classFieldName, null, null, true, true), Integer.MAX_VALUE);
double time = 0d;
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
StoredDocument doc = reader.document(scoreDoc.doc);
String correctAnswer = doc.get(classFieldName);
String[] correctAnswers = doc.getValues(classFieldName);
if (correctAnswer != null && correctAnswer.length() > 0) {
if (correctAnswers != null && correctAnswers.length > 0) {
Arrays.sort(correctAnswers);
ClassificationResult<T> result;
String text = doc.get(textFieldName);
if (text != null) {
@ -92,6 +95,13 @@ public class ConfusionMatrixGenerator {
if (assignedClass != null) {
String classified = assignedClass instanceof BytesRef ? ((BytesRef) assignedClass).utf8ToString() : assignedClass.toString();
String correctAnswer;
if (Arrays.binarySearch(correctAnswers, classified) >= 0) {
correctAnswer = classified;
} else {
correctAnswer = correctAnswers[0];
}
Map<String, Long> stringLongMap = counts.get(correctAnswer);
if (stringLongMap != null) {
Long aLong = stringLongMap.get(classified);
@ -105,6 +115,7 @@ public class ConfusionMatrixGenerator {
stringLongMap.put(classified, 1l);
counts.put(correctAnswer, stringLongMap);
}
}
}
} catch (TimeoutException timeoutException) {
@ -131,6 +142,7 @@ public class ConfusionMatrixGenerator {
private final Map<String, Map<String, Long>> linearizedMatrix;
private final double avgClassificationTime;
private final int numberOfEvaluatedDocs;
private double accuracy = -1d;
private ConfusionMatrix(Map<String, Map<String, Long>> linearizedMatrix, double avgClassificationTime, int numberOfEvaluatedDocs) {
this.linearizedMatrix = linearizedMatrix;
@ -140,12 +152,42 @@ public class ConfusionMatrixGenerator {
/**
* get the linearized confusion matrix as a {@link Map}
* @return a {@link Map} whose keys are the correct answers and whose values are the actual answers' counts
* @return a {@link Map} whose keys are the correct classification answers and whose values are the actual answers'
* counts
*/
public Map<String, Map<String, Long>> getLinearizedMatrix() {
return Collections.unmodifiableMap(linearizedMatrix);
}
/**
* Calculate accuracy on this confusion matrix using the formula:
* {@literal accuracy = correctly-classified / (correctly-classified + wrongly-classified)}
*
* @return the accuracy
*/
public double getAccuracy() {
if (this.accuracy == -1) {
double cc = 0d;
double wc = 0d;
for (Map.Entry<String, Map<String, Long>> entry : linearizedMatrix.entrySet()) {
String correctAnswer = entry.getKey();
for (Map.Entry<String, Long> classifiedAnswers : entry.getValue().entrySet()) {
Long value = classifiedAnswers.getValue();
if (value != null) {
if (correctAnswer.equals(classifiedAnswers.getKey())) {
cc += value;
} else {
wc += value;
}
}
}
}
this.accuracy = cc / (cc + wc);
}
return this.accuracy;
}
@Override
public String toString() {
return "ConfusionMatrix{" +

View File

@ -124,12 +124,18 @@ public class DatasetSplitter {
}
b++;
}
} catch (Exception e) {
throw new IOException(e);
} finally {
// commit
testWriter.commit();
cvWriter.commit();
trainingWriter.commit();
// merge
testWriter.forceMerge(3);
cvWriter.forceMerge(3);
trainingWriter.forceMerge(3);
} catch (Exception e) {
throw new IOException(e);
} finally {
// close IWs
testWriter.close();
cvWriter.close();

View File

@ -79,7 +79,7 @@ public class BooleanPerceptronClassifierTest extends ClassificationTestBase<Bool
LeafReader leafReader = getRandomIndex(analyzer, 100);
try {
long trainStart = System.currentTimeMillis();
BooleanPerceptronClassifier classifier = new BooleanPerceptronClassifier(leafReader, analyzer, null, 1, 0d, booleanFieldName, textFieldName);
BooleanPerceptronClassifier classifier = new BooleanPerceptronClassifier(leafReader, analyzer, null, 1, null, booleanFieldName, textFieldName);
long trainEnd = System.currentTimeMillis();
long trainTime = trainEnd - trainStart;
assertTrue("training took more than 10s: " + trainTime / 1000 + "s", trainTime < 10000);
@ -93,6 +93,8 @@ public class BooleanPerceptronClassifierTest extends ClassificationTestBase<Bool
assertTrue("evaluation took more than 1m: " + evaluationTime / 1000 + "s", evaluationTime < 60000);
double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
assertTrue(5000 > avgClassificationTime);
double accuracy = confusionMatrix.getAccuracy();
assertTrue(accuracy > 0d);
} finally {
leafReader.close();
}

View File

@ -108,6 +108,8 @@ public class CachingNaiveBayesClassifierTest extends ClassificationTestBase<Byte
assertTrue("evaluation took more than 1m: " + evaluationTime / 1000 + "s", evaluationTime < 60000);
double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
assertTrue(5000 > avgClassificationTime);
double accuracy = confusionMatrix.getAccuracy();
assertTrue(accuracy > 0d);
} finally {
leafReader.close();
}

View File

@ -124,7 +124,7 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase<Bytes
try {
long trainStart = System.currentTimeMillis();
KNearestNeighborClassifier kNearestNeighborClassifier = new KNearestNeighborClassifier(leafReader, null,
analyzer, null, 1, 2, 2, categoryFieldName, textFieldName);
analyzer, null, 1, 1, 1, categoryFieldName, textFieldName);
long trainEnd = System.currentTimeMillis();
long trainTime = trainEnd - trainStart;
assertTrue("training took more than 10s: " + trainTime / 1000 + "s", trainTime < 10000);
@ -138,6 +138,8 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase<Bytes
assertTrue("evaluation took more than 2m: " + evaluationTime / 1000 + "s", evaluationTime < 120000);
double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
assertTrue(5000 > avgClassificationTime);
double accuracy = confusionMatrix.getAccuracy();
assertTrue(accuracy > 0d);
} finally {
leafReader.close();
}

View File

@ -109,6 +109,8 @@ public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase<Bytes
assertTrue("evaluation took more than 2m: " + evaluationTime / 1000 + "s", evaluationTime < 120000);
double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
assertTrue("avg classification time: " + avgClassificationTime, 5000 > avgClassificationTime);
double accuracy = confusionMatrix.getAccuracy();
assertTrue(accuracy > 0d);
} finally {
leafReader.close();
}

View File

@ -84,6 +84,7 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
assertNotNull(confusionMatrix.getLinearizedMatrix());
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
assertTrue(confusionMatrix.getAvgClassificationTime() > 0d);
assertTrue(confusionMatrix.getAccuracy() > 0d);
} finally {
if (reader != null) {
reader.close();
@ -103,6 +104,7 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
assertNotNull(confusionMatrix.getLinearizedMatrix());
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
assertTrue(confusionMatrix.getAvgClassificationTime() > 0d);
assertTrue(confusionMatrix.getAccuracy() > 0d);
} finally {
if (reader != null) {
reader.close();
@ -122,6 +124,7 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
assertNotNull(confusionMatrix.getLinearizedMatrix());
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
assertTrue(confusionMatrix.getAvgClassificationTime() > 0d);
assertTrue(confusionMatrix.getAccuracy() > 0d);
} finally {
if (reader != null) {
reader.close();
@ -141,6 +144,7 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
assertNotNull(confusionMatrix.getLinearizedMatrix());
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
assertTrue(confusionMatrix.getAccuracy() > 0d);
} finally {
if (reader != null) {
reader.close();