LUCENE-7193 - add generic f1-measure metric to confusion matrix

This commit is contained in:
Tommaso Teofili 2016-04-08 11:02:48 +02:00
parent b02b026b7d
commit 2507015f4c
2 changed files with 54 additions and 15 deletions

View File

@ -56,12 +56,13 @@ public class ConfusionMatrixGenerator {
* @param classifier the {@link Classifier} whose confusion matrix has to be generated
* @param classFieldName the name of the Lucene field used as the classifier's output
* @param textFieldName the nome the Lucene field used as the classifier's input
* @param timeoutMilliseconds timeout to wait before stopping creating the confusion matrix
* @param <T> the return type of the {@link ClassificationResult} returned by the given {@link Classifier}
* @return a {@link org.apache.lucene.classification.utils.ConfusionMatrixGenerator.ConfusionMatrix}
* @throws IOException if problems occurr while reading the index or using the classifier
*/
public static <T> ConfusionMatrix getConfusionMatrix(LeafReader reader, Classifier<T> classifier, String classFieldName,
String textFieldName) throws IOException {
String textFieldName, long timeoutMilliseconds) throws IOException {
ExecutorService executorService = Executors.newFixedThreadPool(1, new NamedThreadFactory("confusion-matrix-gen-"));
@ -72,7 +73,13 @@ public class ConfusionMatrixGenerator {
TopDocs topDocs = indexSearcher.search(new TermRangeQuery(classFieldName, null, null, true, true), Integer.MAX_VALUE);
double time = 0d;
int counter = 0;
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
if (timeoutMilliseconds > 0 && time >= timeoutMilliseconds) {
break;
}
Document doc = reader.document(scoreDoc.doc);
String[] correctAnswers = doc.getValues(classFieldName);
@ -91,6 +98,7 @@ public class ConfusionMatrixGenerator {
if (result != null) {
T assignedClass = result.getAssignedClass();
if (assignedClass != null) {
counter++;
String classified = assignedClass instanceof BytesRef ? ((BytesRef) assignedClass).utf8ToString() : assignedClass.toString();
String correctAnswer;
@ -117,7 +125,7 @@ public class ConfusionMatrixGenerator {
}
}
} catch (TimeoutException timeoutException) {
// add timeout
// add classification timeout
time += 5000;
} catch (ExecutionException | InterruptedException executionException) {
throw new RuntimeException(executionException);
@ -126,7 +134,7 @@ public class ConfusionMatrixGenerator {
}
}
}
return new ConfusionMatrix(counts, time / topDocs.totalHits, topDocs.totalHits);
return new ConfusionMatrix(counts, time / counter, counter);
} finally {
executorService.shutdown();
}
@ -167,7 +175,7 @@ public class ConfusionMatrixGenerator {
public double getPrecision(String klass) {
Map<String, Long> classifications = linearizedMatrix.get(klass);
double tp = 0;
double fp = -1;
double fp = 0;
if (classifications != null) {
for (Map.Entry<String, Long> entry : classifications.entrySet()) {
if (klass.equals(entry.getKey())) {
@ -180,7 +188,7 @@ public class ConfusionMatrixGenerator {
}
}
}
return tp + fp > 0 ? tp / (tp + fp) : 0;
return tp > 0 ? tp / (tp + fp) : 0;
}
/**
@ -217,6 +225,17 @@ public class ConfusionMatrixGenerator {
return precision > 0 && recall > 0 ? 2 * precision * recall / (precision + recall) : 0;
}
/**
* get the F-1 measure on this confusion matrix
*
* @return the F-1 measure
*/
public double getF1Measure() {
double recall = getRecall();
double precision = getPrecision();
return precision > 0 && recall > 0 ? 2 * precision * recall / (precision + recall) : 0;
}
/**
* Calculate accuracy on this confusion matrix using the formula:
* {@literal accuracy = correctly-classified / (correctly-classified + wrongly-classified)}

View File

@ -59,7 +59,8 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
return null;
}
};
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, categoryFieldName, textFieldName);
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
classifier, categoryFieldName, textFieldName, -1);
assertNotNull(confusionMatrix);
assertNotNull(confusionMatrix.getLinearizedMatrix());
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
@ -74,6 +75,9 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
double recall = confusionMatrix.getRecall();
assertTrue(recall >= 0d);
assertTrue(recall <= 1d);
double f1Measure = confusionMatrix.getF1Measure();
assertTrue(f1Measure >= 0d);
assertTrue(f1Measure <= 1d);
} finally {
if (reader != null) {
reader.close();
@ -88,7 +92,8 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
MockAnalyzer analyzer = new MockAnalyzer(random());
reader = getSampleIndex(analyzer);
Classifier<BytesRef> classifier = new SimpleNaiveBayesClassifier(reader, analyzer, null, categoryFieldName, textFieldName);
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, categoryFieldName, textFieldName);
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
classifier, categoryFieldName, textFieldName, -1);
assertNotNull(confusionMatrix);
assertNotNull(confusionMatrix.getLinearizedMatrix());
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
@ -102,6 +107,9 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
double recall = confusionMatrix.getRecall();
assertTrue(recall >= 0d);
assertTrue(recall <= 1d);
double f1Measure = confusionMatrix.getF1Measure();
assertTrue(f1Measure >= 0d);
assertTrue(f1Measure <= 1d);
} finally {
if (reader != null) {
reader.close();
@ -116,7 +124,8 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
MockAnalyzer analyzer = new MockAnalyzer(random());
reader = getSampleIndex(analyzer);
Classifier<BytesRef> classifier = new CachingNaiveBayesClassifier(reader, analyzer, null, categoryFieldName, textFieldName);
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, categoryFieldName, textFieldName);
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
classifier, categoryFieldName, textFieldName, -1);
assertNotNull(confusionMatrix);
assertNotNull(confusionMatrix.getLinearizedMatrix());
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
@ -130,6 +139,9 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
double recall = confusionMatrix.getRecall();
assertTrue(recall >= 0d);
assertTrue(recall <= 1d);
double f1Measure = confusionMatrix.getF1Measure();
assertTrue(f1Measure >= 0d);
assertTrue(f1Measure <= 1d);
} finally {
if (reader != null) {
reader.close();
@ -144,7 +156,8 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
MockAnalyzer analyzer = new MockAnalyzer(random());
reader = getSampleIndex(analyzer);
Classifier<BytesRef> classifier = new KNearestNeighborClassifier(reader, null, analyzer, null, 1, 0, 0, categoryFieldName, textFieldName);
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, categoryFieldName, textFieldName);
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
classifier, categoryFieldName, textFieldName, -1);
assertNotNull(confusionMatrix);
assertNotNull(confusionMatrix.getLinearizedMatrix());
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
@ -158,6 +171,9 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
double recall = confusionMatrix.getRecall();
assertTrue(recall >= 0d);
assertTrue(recall <= 1d);
double f1Measure = confusionMatrix.getF1Measure();
assertTrue(f1Measure >= 0d);
assertTrue(f1Measure <= 1d);
} finally {
if (reader != null) {
reader.close();
@ -172,7 +188,8 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
MockAnalyzer analyzer = new MockAnalyzer(random());
reader = getSampleIndex(analyzer);
Classifier<Boolean> classifier = new BooleanPerceptronClassifier(reader, analyzer, null, 1, null, booleanFieldName, textFieldName);
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, booleanFieldName, textFieldName);
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
classifier, booleanFieldName, textFieldName, -1);
assertNotNull(confusionMatrix);
assertNotNull(confusionMatrix.getLinearizedMatrix());
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
@ -186,6 +203,9 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
double recall = confusionMatrix.getRecall();
assertTrue(recall >= 0d);
assertTrue(recall <= 1d);
double f1Measure = confusionMatrix.getF1Measure();
assertTrue(f1Measure >= 0d);
assertTrue(f1Measure <= 1d);
assertTrue(confusionMatrix.getPrecision("true") >= 0d);
assertTrue(confusionMatrix.getPrecision("true") <= 1d);
assertTrue(confusionMatrix.getPrecision("false") >= 0d);