mirror of https://github.com/apache/lucene.git
LUCENE-7193 - add generic f1-measure metric to confusion matrix
This commit is contained in:
parent
b02b026b7d
commit
2507015f4c
|
@ -56,12 +56,13 @@ public class ConfusionMatrixGenerator {
|
|||
* @param classifier the {@link Classifier} whose confusion matrix has to be generated
|
||||
* @param classFieldName the name of the Lucene field used as the classifier's output
|
||||
* @param textFieldName the nome the Lucene field used as the classifier's input
|
||||
* @param timeoutMilliseconds timeout to wait before stopping creating the confusion matrix
|
||||
* @param <T> the return type of the {@link ClassificationResult} returned by the given {@link Classifier}
|
||||
* @return a {@link org.apache.lucene.classification.utils.ConfusionMatrixGenerator.ConfusionMatrix}
|
||||
* @throws IOException if problems occurr while reading the index or using the classifier
|
||||
*/
|
||||
public static <T> ConfusionMatrix getConfusionMatrix(LeafReader reader, Classifier<T> classifier, String classFieldName,
|
||||
String textFieldName) throws IOException {
|
||||
String textFieldName, long timeoutMilliseconds) throws IOException {
|
||||
|
||||
ExecutorService executorService = Executors.newFixedThreadPool(1, new NamedThreadFactory("confusion-matrix-gen-"));
|
||||
|
||||
|
@ -72,7 +73,13 @@ public class ConfusionMatrixGenerator {
|
|||
TopDocs topDocs = indexSearcher.search(new TermRangeQuery(classFieldName, null, null, true, true), Integer.MAX_VALUE);
|
||||
double time = 0d;
|
||||
|
||||
int counter = 0;
|
||||
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
|
||||
|
||||
if (timeoutMilliseconds > 0 && time >= timeoutMilliseconds) {
|
||||
break;
|
||||
}
|
||||
|
||||
Document doc = reader.document(scoreDoc.doc);
|
||||
String[] correctAnswers = doc.getValues(classFieldName);
|
||||
|
||||
|
@ -91,6 +98,7 @@ public class ConfusionMatrixGenerator {
|
|||
if (result != null) {
|
||||
T assignedClass = result.getAssignedClass();
|
||||
if (assignedClass != null) {
|
||||
counter++;
|
||||
String classified = assignedClass instanceof BytesRef ? ((BytesRef) assignedClass).utf8ToString() : assignedClass.toString();
|
||||
|
||||
String correctAnswer;
|
||||
|
@ -117,7 +125,7 @@ public class ConfusionMatrixGenerator {
|
|||
}
|
||||
}
|
||||
} catch (TimeoutException timeoutException) {
|
||||
// add timeout
|
||||
// add classification timeout
|
||||
time += 5000;
|
||||
} catch (ExecutionException | InterruptedException executionException) {
|
||||
throw new RuntimeException(executionException);
|
||||
|
@ -126,7 +134,7 @@ public class ConfusionMatrixGenerator {
|
|||
}
|
||||
}
|
||||
}
|
||||
return new ConfusionMatrix(counts, time / topDocs.totalHits, topDocs.totalHits);
|
||||
return new ConfusionMatrix(counts, time / counter, counter);
|
||||
} finally {
|
||||
executorService.shutdown();
|
||||
}
|
||||
|
@ -167,7 +175,7 @@ public class ConfusionMatrixGenerator {
|
|||
public double getPrecision(String klass) {
|
||||
Map<String, Long> classifications = linearizedMatrix.get(klass);
|
||||
double tp = 0;
|
||||
double fp = -1;
|
||||
double fp = 0;
|
||||
if (classifications != null) {
|
||||
for (Map.Entry<String, Long> entry : classifications.entrySet()) {
|
||||
if (klass.equals(entry.getKey())) {
|
||||
|
@ -180,7 +188,7 @@ public class ConfusionMatrixGenerator {
|
|||
}
|
||||
}
|
||||
}
|
||||
return tp + fp > 0 ? tp / (tp + fp) : 0;
|
||||
return tp > 0 ? tp / (tp + fp) : 0;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -217,6 +225,17 @@ public class ConfusionMatrixGenerator {
|
|||
return precision > 0 && recall > 0 ? 2 * precision * recall / (precision + recall) : 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* get the F-1 measure on this confusion matrix
|
||||
*
|
||||
* @return the F-1 measure
|
||||
*/
|
||||
public double getF1Measure() {
|
||||
double recall = getRecall();
|
||||
double precision = getPrecision();
|
||||
return precision > 0 && recall > 0 ? 2 * precision * recall / (precision + recall) : 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate accuracy on this confusion matrix using the formula:
|
||||
* {@literal accuracy = correctly-classified / (correctly-classified + wrongly-classified)}
|
||||
|
|
|
@ -59,7 +59,8 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
|
|||
return null;
|
||||
}
|
||||
};
|
||||
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, categoryFieldName, textFieldName);
|
||||
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
|
||||
classifier, categoryFieldName, textFieldName, -1);
|
||||
assertNotNull(confusionMatrix);
|
||||
assertNotNull(confusionMatrix.getLinearizedMatrix());
|
||||
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
|
||||
|
@ -74,6 +75,9 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
|
|||
double recall = confusionMatrix.getRecall();
|
||||
assertTrue(recall >= 0d);
|
||||
assertTrue(recall <= 1d);
|
||||
double f1Measure = confusionMatrix.getF1Measure();
|
||||
assertTrue(f1Measure >= 0d);
|
||||
assertTrue(f1Measure <= 1d);
|
||||
} finally {
|
||||
if (reader != null) {
|
||||
reader.close();
|
||||
|
@ -88,7 +92,8 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
|
|||
MockAnalyzer analyzer = new MockAnalyzer(random());
|
||||
reader = getSampleIndex(analyzer);
|
||||
Classifier<BytesRef> classifier = new SimpleNaiveBayesClassifier(reader, analyzer, null, categoryFieldName, textFieldName);
|
||||
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, categoryFieldName, textFieldName);
|
||||
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
|
||||
classifier, categoryFieldName, textFieldName, -1);
|
||||
assertNotNull(confusionMatrix);
|
||||
assertNotNull(confusionMatrix.getLinearizedMatrix());
|
||||
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
|
||||
|
@ -102,6 +107,9 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
|
|||
double recall = confusionMatrix.getRecall();
|
||||
assertTrue(recall >= 0d);
|
||||
assertTrue(recall <= 1d);
|
||||
double f1Measure = confusionMatrix.getF1Measure();
|
||||
assertTrue(f1Measure >= 0d);
|
||||
assertTrue(f1Measure <= 1d);
|
||||
} finally {
|
||||
if (reader != null) {
|
||||
reader.close();
|
||||
|
@ -116,7 +124,8 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
|
|||
MockAnalyzer analyzer = new MockAnalyzer(random());
|
||||
reader = getSampleIndex(analyzer);
|
||||
Classifier<BytesRef> classifier = new CachingNaiveBayesClassifier(reader, analyzer, null, categoryFieldName, textFieldName);
|
||||
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, categoryFieldName, textFieldName);
|
||||
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
|
||||
classifier, categoryFieldName, textFieldName, -1);
|
||||
assertNotNull(confusionMatrix);
|
||||
assertNotNull(confusionMatrix.getLinearizedMatrix());
|
||||
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
|
||||
|
@ -130,6 +139,9 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
|
|||
double recall = confusionMatrix.getRecall();
|
||||
assertTrue(recall >= 0d);
|
||||
assertTrue(recall <= 1d);
|
||||
double f1Measure = confusionMatrix.getF1Measure();
|
||||
assertTrue(f1Measure >= 0d);
|
||||
assertTrue(f1Measure <= 1d);
|
||||
} finally {
|
||||
if (reader != null) {
|
||||
reader.close();
|
||||
|
@ -144,7 +156,8 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
|
|||
MockAnalyzer analyzer = new MockAnalyzer(random());
|
||||
reader = getSampleIndex(analyzer);
|
||||
Classifier<BytesRef> classifier = new KNearestNeighborClassifier(reader, null, analyzer, null, 1, 0, 0, categoryFieldName, textFieldName);
|
||||
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, categoryFieldName, textFieldName);
|
||||
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
|
||||
classifier, categoryFieldName, textFieldName, -1);
|
||||
assertNotNull(confusionMatrix);
|
||||
assertNotNull(confusionMatrix.getLinearizedMatrix());
|
||||
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
|
||||
|
@ -158,6 +171,9 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
|
|||
double recall = confusionMatrix.getRecall();
|
||||
assertTrue(recall >= 0d);
|
||||
assertTrue(recall <= 1d);
|
||||
double f1Measure = confusionMatrix.getF1Measure();
|
||||
assertTrue(f1Measure >= 0d);
|
||||
assertTrue(f1Measure <= 1d);
|
||||
} finally {
|
||||
if (reader != null) {
|
||||
reader.close();
|
||||
|
@ -172,7 +188,8 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
|
|||
MockAnalyzer analyzer = new MockAnalyzer(random());
|
||||
reader = getSampleIndex(analyzer);
|
||||
Classifier<Boolean> classifier = new BooleanPerceptronClassifier(reader, analyzer, null, 1, null, booleanFieldName, textFieldName);
|
||||
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, booleanFieldName, textFieldName);
|
||||
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
|
||||
classifier, booleanFieldName, textFieldName, -1);
|
||||
assertNotNull(confusionMatrix);
|
||||
assertNotNull(confusionMatrix.getLinearizedMatrix());
|
||||
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
|
||||
|
@ -186,6 +203,9 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
|
|||
double recall = confusionMatrix.getRecall();
|
||||
assertTrue(recall >= 0d);
|
||||
assertTrue(recall <= 1d);
|
||||
double f1Measure = confusionMatrix.getF1Measure();
|
||||
assertTrue(f1Measure >= 0d);
|
||||
assertTrue(f1Measure <= 1d);
|
||||
assertTrue(confusionMatrix.getPrecision("true") >= 0d);
|
||||
assertTrue(confusionMatrix.getPrecision("true") <= 1d);
|
||||
assertTrue(confusionMatrix.getPrecision("false") >= 0d);
|
||||
|
|
Loading…
Reference in New Issue