diff --git a/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java b/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java new file mode 100644 index 00000000000..21e95797838 --- /dev/null +++ b/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java @@ -0,0 +1,111 @@ +package org.apache.lucene.classification.utils; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.apache.lucene.classification.ClassificationResult; +import org.apache.lucene.classification.Classifier; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.StoredDocument; +import org.apache.lucene.util.BytesRef; + +/** + * Utility class to generate the confusion matrix of a {@link Classifier} + */ +public class ConfusionMatrixGenerator { + + private ConfusionMatrixGenerator() { + + } + + /** + * get the {@link org.apache.lucene.classification.utils.ConfusionMatrixGenerator.ConfusionMatrix} of a given {@link Classifier}, + * generated on the given {@link LeafReader}, class and text fields. + * + * @param reader the {@link LeafReader} containing the index used for creating the {@link Classifier} + * @param classifier the {@link Classifier} whose confusion matrix has to be generated + * @param classFieldName the name of the Lucene field used as the classifier's output + * @param textFieldName the nome the Lucene field used as the classifier's input + * @param the return type of the {@link ClassificationResult} returned by the given {@link Classifier} + * @return a {@link org.apache.lucene.classification.utils.ConfusionMatrixGenerator.ConfusionMatrix} + * @throws IOException if problems occurr while reading the index or using the classifier + */ + public static ConfusionMatrix getConfusionMatrix(LeafReader reader, Classifier classifier, String classFieldName, String textFieldName) throws IOException { + + Map> counts = new HashMap<>(); + + for (int i = 0; i < reader.maxDoc(); i++) { + StoredDocument doc = reader.document(i); + String correctAnswer = doc.get(classFieldName); + + if (correctAnswer != null && correctAnswer.length() > 0) { + + ClassificationResult result = classifier.assignClass(doc.get(textFieldName)); + T assignedClass = result.getAssignedClass(); + String classified = assignedClass instanceof BytesRef ? ((BytesRef) assignedClass).utf8ToString() : assignedClass.toString(); + + Map stringLongMap = counts.get(correctAnswer); + if (stringLongMap != null) { + Long aLong = stringLongMap.get(classified); + if (aLong != null) { + stringLongMap.put(classified, aLong + 1); + } else { + stringLongMap.put(classified, 1l); + } + } else { + stringLongMap = new HashMap<>(); + stringLongMap.put(classified, 1l); + counts.put(correctAnswer, stringLongMap); + } + + } + } + return new ConfusionMatrix(counts); + } + + /** + * a confusion matrix, backed by a {@link Map} representing the linearized matrix + */ + public static class ConfusionMatrix { + + private final Map> linearizedMatrix; + + private ConfusionMatrix(Map> linearizedMatrix) { + this.linearizedMatrix = linearizedMatrix; + } + + /** + * get the linearized confusion matrix as a {@link Map} + * @return a {@link Map} whose keys are the correct answers and whose values are the actual answers' counts + */ + public Map> getLinearizedMatrix() { + return Collections.unmodifiableMap(linearizedMatrix); + } + + @Override + public String toString() { + return "ConfusionMatrix{" + + "linearizedMatrix=" + linearizedMatrix + + '}'; + } + } +} diff --git a/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java b/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java index 0822588cf5a..e09ced00156 100644 --- a/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java +++ b/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java @@ -52,9 +52,9 @@ public abstract class ClassificationTestBase extends LuceneTestCase { private Directory dir; private FieldType ft; - String textFieldName; - String categoryFieldName; - String booleanFieldName; + protected String textFieldName; + protected String categoryFieldName; + protected String booleanFieldName; @Override @Before diff --git a/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java b/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java new file mode 100644 index 00000000000..1b7d21c4152 --- /dev/null +++ b/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java @@ -0,0 +1,103 @@ +package org.apache.lucene.classification.utils; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.lucene.analysis.MockAnalyzer; +import org.apache.lucene.classification.BooleanPerceptronClassifier; +import org.apache.lucene.classification.CachingNaiveBayesClassifier; +import org.apache.lucene.classification.ClassificationTestBase; +import org.apache.lucene.classification.Classifier; +import org.apache.lucene.classification.KNearestNeighborClassifier; +import org.apache.lucene.classification.SimpleNaiveBayesClassifier; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.util.BytesRef; +import org.junit.Test; + +/** + * Tests for {@link ConfusionMatrixGenerator} + */ +public class ConfusionMatrixGeneratorTest extends ClassificationTestBase { + + @Test + public void testGetConfusionMatrixWithSNB() throws Exception { + LeafReader reader = null; + try { + MockAnalyzer analyzer = new MockAnalyzer(random()); + reader = populateSampleIndex(analyzer); + Classifier classifier = new SimpleNaiveBayesClassifier(reader, analyzer, null, categoryFieldName, textFieldName); + ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, categoryFieldName, textFieldName); + assertNotNull(confusionMatrix); + assertNotNull(confusionMatrix.getLinearizedMatrix()); + } finally { + if (reader != null) { + reader.close(); + } + } + } + + @Test + public void testGetConfusionMatrixWithCNB() throws Exception { + LeafReader reader = null; + try { + MockAnalyzer analyzer = new MockAnalyzer(random()); + reader = populateSampleIndex(analyzer); + Classifier classifier = new CachingNaiveBayesClassifier(reader, analyzer, null, categoryFieldName, textFieldName); + ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, categoryFieldName, textFieldName); + assertNotNull(confusionMatrix); + assertNotNull(confusionMatrix.getLinearizedMatrix()); + } finally { + if (reader != null) { + reader.close(); + } + } + } + + @Test + public void testGetConfusionMatrixWithKNN() throws Exception { + LeafReader reader = null; + try { + MockAnalyzer analyzer = new MockAnalyzer(random()); + reader = populateSampleIndex(analyzer); + Classifier classifier = new KNearestNeighborClassifier(reader, analyzer, null, 1, 0, 0, categoryFieldName, textFieldName); + ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, categoryFieldName, textFieldName); + assertNotNull(confusionMatrix); + assertNotNull(confusionMatrix.getLinearizedMatrix()); + } finally { + if (reader != null) { + reader.close(); + } + } + } + + @Test + public void testGetConfusionMatrixWithBP() throws Exception { + LeafReader reader = null; + try { + MockAnalyzer analyzer = new MockAnalyzer(random()); + reader = populateSampleIndex(analyzer); + Classifier classifier = new BooleanPerceptronClassifier(reader, analyzer, null, 1, null, booleanFieldName, textFieldName); + ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, booleanFieldName, textFieldName); + assertNotNull(confusionMatrix); + assertNotNull(confusionMatrix.getLinearizedMatrix()); + } finally { + if (reader != null) { + reader.close(); + } + } + } +} \ No newline at end of file