mirror of https://github.com/apache/lucene.git
LUCENE-6479 - added a raw accuracy calculation to confusion matrix, minor adjustments to splitter
git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1710197 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
34c87c2c4a
commit
308d0101f3
|
@ -19,6 +19,7 @@ package org.apache.lucene.classification;
|
|||
import java.io.IOException;
|
||||
import java.io.StringReader;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
|
@ -233,4 +234,14 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
|
|||
}
|
||||
return returnList;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "KNearestNeighborClassifier{" +
|
||||
", textFieldNames=" + Arrays.toString(textFieldNames) +
|
||||
", classFieldName='" + classFieldName + '\'' +
|
||||
", k=" + k +
|
||||
", query=" + query +
|
||||
'}';
|
||||
}
|
||||
}
|
|
@ -18,6 +18,7 @@ package org.apache.lucene.classification.utils;
|
|||
*/
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
@ -34,6 +35,7 @@ import org.apache.lucene.index.StoredDocument;
|
|||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.lucene.search.TermRangeQuery;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.search.WildcardQuery;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
|
@ -69,14 +71,15 @@ public class ConfusionMatrixGenerator {
|
|||
|
||||
Map<String, Map<String, Long>> counts = new HashMap<>();
|
||||
IndexSearcher indexSearcher = new IndexSearcher(reader);
|
||||
TopDocs topDocs = indexSearcher.search(new WildcardQuery(new Term(classFieldName, "*")), Integer.MAX_VALUE);
|
||||
TopDocs topDocs = indexSearcher.search(new TermRangeQuery(classFieldName, null, null, true, true), Integer.MAX_VALUE);
|
||||
double time = 0d;
|
||||
|
||||
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
|
||||
StoredDocument doc = reader.document(scoreDoc.doc);
|
||||
String correctAnswer = doc.get(classFieldName);
|
||||
String[] correctAnswers = doc.getValues(classFieldName);
|
||||
|
||||
if (correctAnswer != null && correctAnswer.length() > 0) {
|
||||
if (correctAnswers != null && correctAnswers.length > 0) {
|
||||
Arrays.sort(correctAnswers);
|
||||
ClassificationResult<T> result;
|
||||
String text = doc.get(textFieldName);
|
||||
if (text != null) {
|
||||
|
@ -92,6 +95,13 @@ public class ConfusionMatrixGenerator {
|
|||
if (assignedClass != null) {
|
||||
String classified = assignedClass instanceof BytesRef ? ((BytesRef) assignedClass).utf8ToString() : assignedClass.toString();
|
||||
|
||||
String correctAnswer;
|
||||
if (Arrays.binarySearch(correctAnswers, classified) >= 0) {
|
||||
correctAnswer = classified;
|
||||
} else {
|
||||
correctAnswer = correctAnswers[0];
|
||||
}
|
||||
|
||||
Map<String, Long> stringLongMap = counts.get(correctAnswer);
|
||||
if (stringLongMap != null) {
|
||||
Long aLong = stringLongMap.get(classified);
|
||||
|
@ -105,6 +115,7 @@ public class ConfusionMatrixGenerator {
|
|||
stringLongMap.put(classified, 1l);
|
||||
counts.put(correctAnswer, stringLongMap);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
} catch (TimeoutException timeoutException) {
|
||||
|
@ -131,6 +142,7 @@ public class ConfusionMatrixGenerator {
|
|||
private final Map<String, Map<String, Long>> linearizedMatrix;
|
||||
private final double avgClassificationTime;
|
||||
private final int numberOfEvaluatedDocs;
|
||||
private double accuracy = -1d;
|
||||
|
||||
private ConfusionMatrix(Map<String, Map<String, Long>> linearizedMatrix, double avgClassificationTime, int numberOfEvaluatedDocs) {
|
||||
this.linearizedMatrix = linearizedMatrix;
|
||||
|
@ -140,12 +152,42 @@ public class ConfusionMatrixGenerator {
|
|||
|
||||
/**
|
||||
* get the linearized confusion matrix as a {@link Map}
|
||||
* @return a {@link Map} whose keys are the correct answers and whose values are the actual answers' counts
|
||||
* @return a {@link Map} whose keys are the correct classification answers and whose values are the actual answers'
|
||||
* counts
|
||||
*/
|
||||
public Map<String, Map<String, Long>> getLinearizedMatrix() {
|
||||
return Collections.unmodifiableMap(linearizedMatrix);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate accuracy on this confusion matrix using the formula:
|
||||
* {@literal accuracy = correctly-classified / (correctly-classified + wrongly-classified)}
|
||||
*
|
||||
* @return the accuracy
|
||||
*/
|
||||
public double getAccuracy() {
|
||||
if (this.accuracy == -1) {
|
||||
double cc = 0d;
|
||||
double wc = 0d;
|
||||
for (Map.Entry<String, Map<String, Long>> entry : linearizedMatrix.entrySet()) {
|
||||
String correctAnswer = entry.getKey();
|
||||
for (Map.Entry<String, Long> classifiedAnswers : entry.getValue().entrySet()) {
|
||||
Long value = classifiedAnswers.getValue();
|
||||
if (value != null) {
|
||||
if (correctAnswer.equals(classifiedAnswers.getKey())) {
|
||||
cc += value;
|
||||
} else {
|
||||
wc += value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
this.accuracy = cc / (cc + wc);
|
||||
}
|
||||
return this.accuracy;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "ConfusionMatrix{" +
|
||||
|
|
|
@ -124,12 +124,18 @@ public class DatasetSplitter {
|
|||
}
|
||||
b++;
|
||||
}
|
||||
} catch (Exception e) {
|
||||
throw new IOException(e);
|
||||
} finally {
|
||||
// commit
|
||||
testWriter.commit();
|
||||
cvWriter.commit();
|
||||
trainingWriter.commit();
|
||||
|
||||
// merge
|
||||
testWriter.forceMerge(3);
|
||||
cvWriter.forceMerge(3);
|
||||
trainingWriter.forceMerge(3);
|
||||
} catch (Exception e) {
|
||||
throw new IOException(e);
|
||||
} finally {
|
||||
// close IWs
|
||||
testWriter.close();
|
||||
cvWriter.close();
|
||||
|
|
|
@ -79,7 +79,7 @@ public class BooleanPerceptronClassifierTest extends ClassificationTestBase<Bool
|
|||
LeafReader leafReader = getRandomIndex(analyzer, 100);
|
||||
try {
|
||||
long trainStart = System.currentTimeMillis();
|
||||
BooleanPerceptronClassifier classifier = new BooleanPerceptronClassifier(leafReader, analyzer, null, 1, 0d, booleanFieldName, textFieldName);
|
||||
BooleanPerceptronClassifier classifier = new BooleanPerceptronClassifier(leafReader, analyzer, null, 1, null, booleanFieldName, textFieldName);
|
||||
long trainEnd = System.currentTimeMillis();
|
||||
long trainTime = trainEnd - trainStart;
|
||||
assertTrue("training took more than 10s: " + trainTime / 1000 + "s", trainTime < 10000);
|
||||
|
@ -93,6 +93,8 @@ public class BooleanPerceptronClassifierTest extends ClassificationTestBase<Bool
|
|||
assertTrue("evaluation took more than 1m: " + evaluationTime / 1000 + "s", evaluationTime < 60000);
|
||||
double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
|
||||
assertTrue(5000 > avgClassificationTime);
|
||||
double accuracy = confusionMatrix.getAccuracy();
|
||||
assertTrue(accuracy > 0d);
|
||||
} finally {
|
||||
leafReader.close();
|
||||
}
|
||||
|
|
|
@ -108,6 +108,8 @@ public class CachingNaiveBayesClassifierTest extends ClassificationTestBase<Byte
|
|||
assertTrue("evaluation took more than 1m: " + evaluationTime / 1000 + "s", evaluationTime < 60000);
|
||||
double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
|
||||
assertTrue(5000 > avgClassificationTime);
|
||||
double accuracy = confusionMatrix.getAccuracy();
|
||||
assertTrue(accuracy > 0d);
|
||||
} finally {
|
||||
leafReader.close();
|
||||
}
|
||||
|
|
|
@ -124,7 +124,7 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase<Bytes
|
|||
try {
|
||||
long trainStart = System.currentTimeMillis();
|
||||
KNearestNeighborClassifier kNearestNeighborClassifier = new KNearestNeighborClassifier(leafReader, null,
|
||||
analyzer, null, 1, 2, 2, categoryFieldName, textFieldName);
|
||||
analyzer, null, 1, 1, 1, categoryFieldName, textFieldName);
|
||||
long trainEnd = System.currentTimeMillis();
|
||||
long trainTime = trainEnd - trainStart;
|
||||
assertTrue("training took more than 10s: " + trainTime / 1000 + "s", trainTime < 10000);
|
||||
|
@ -138,6 +138,8 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase<Bytes
|
|||
assertTrue("evaluation took more than 2m: " + evaluationTime / 1000 + "s", evaluationTime < 120000);
|
||||
double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
|
||||
assertTrue(5000 > avgClassificationTime);
|
||||
double accuracy = confusionMatrix.getAccuracy();
|
||||
assertTrue(accuracy > 0d);
|
||||
} finally {
|
||||
leafReader.close();
|
||||
}
|
||||
|
|
|
@ -109,6 +109,8 @@ public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase<Bytes
|
|||
assertTrue("evaluation took more than 2m: " + evaluationTime / 1000 + "s", evaluationTime < 120000);
|
||||
double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
|
||||
assertTrue("avg classification time: " + avgClassificationTime, 5000 > avgClassificationTime);
|
||||
double accuracy = confusionMatrix.getAccuracy();
|
||||
assertTrue(accuracy > 0d);
|
||||
} finally {
|
||||
leafReader.close();
|
||||
}
|
||||
|
|
|
@ -84,6 +84,7 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
|
|||
assertNotNull(confusionMatrix.getLinearizedMatrix());
|
||||
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
|
||||
assertTrue(confusionMatrix.getAvgClassificationTime() > 0d);
|
||||
assertTrue(confusionMatrix.getAccuracy() > 0d);
|
||||
} finally {
|
||||
if (reader != null) {
|
||||
reader.close();
|
||||
|
@ -103,6 +104,7 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
|
|||
assertNotNull(confusionMatrix.getLinearizedMatrix());
|
||||
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
|
||||
assertTrue(confusionMatrix.getAvgClassificationTime() > 0d);
|
||||
assertTrue(confusionMatrix.getAccuracy() > 0d);
|
||||
} finally {
|
||||
if (reader != null) {
|
||||
reader.close();
|
||||
|
@ -122,6 +124,7 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
|
|||
assertNotNull(confusionMatrix.getLinearizedMatrix());
|
||||
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
|
||||
assertTrue(confusionMatrix.getAvgClassificationTime() > 0d);
|
||||
assertTrue(confusionMatrix.getAccuracy() > 0d);
|
||||
} finally {
|
||||
if (reader != null) {
|
||||
reader.close();
|
||||
|
@ -141,6 +144,7 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
|
|||
assertNotNull(confusionMatrix.getLinearizedMatrix());
|
||||
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
|
||||
assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
|
||||
assertTrue(confusionMatrix.getAccuracy() > 0d);
|
||||
} finally {
|
||||
if (reader != null) {
|
||||
reader.close();
|
||||
|
|
Loading…
Reference in New Issue