diff --git a/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java index f822b109fcd..61707228b88 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java @@ -19,6 +19,7 @@ package org.apache.lucene.classification; import java.io.IOException; import java.io.StringReader; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -233,4 +234,14 @@ public class KNearestNeighborClassifier implements Classifier { } return returnList; } + + @Override + public String toString() { + return "KNearestNeighborClassifier{" + + ", textFieldNames=" + Arrays.toString(textFieldNames) + + ", classFieldName='" + classFieldName + '\'' + + ", k=" + k + + ", query=" + query + + '}'; + } } \ No newline at end of file diff --git a/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java b/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java index bef8449b38c..ce48d5c9108 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java @@ -18,6 +18,7 @@ package org.apache.lucene.classification.utils; */ import java.io.IOException; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -34,6 +35,7 @@ import org.apache.lucene.index.StoredDocument; import org.apache.lucene.index.Term; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TermRangeQuery; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.WildcardQuery; import org.apache.lucene.util.BytesRef; @@ -69,14 +71,15 @@ public class ConfusionMatrixGenerator { Map> counts = new HashMap<>(); IndexSearcher indexSearcher = new IndexSearcher(reader); - TopDocs topDocs = indexSearcher.search(new WildcardQuery(new Term(classFieldName, "*")), Integer.MAX_VALUE); + TopDocs topDocs = indexSearcher.search(new TermRangeQuery(classFieldName, null, null, true, true), Integer.MAX_VALUE); double time = 0d; for (ScoreDoc scoreDoc : topDocs.scoreDocs) { StoredDocument doc = reader.document(scoreDoc.doc); - String correctAnswer = doc.get(classFieldName); + String[] correctAnswers = doc.getValues(classFieldName); - if (correctAnswer != null && correctAnswer.length() > 0) { + if (correctAnswers != null && correctAnswers.length > 0) { + Arrays.sort(correctAnswers); ClassificationResult result; String text = doc.get(textFieldName); if (text != null) { @@ -92,6 +95,13 @@ public class ConfusionMatrixGenerator { if (assignedClass != null) { String classified = assignedClass instanceof BytesRef ? ((BytesRef) assignedClass).utf8ToString() : assignedClass.toString(); + String correctAnswer; + if (Arrays.binarySearch(correctAnswers, classified) >= 0) { + correctAnswer = classified; + } else { + correctAnswer = correctAnswers[0]; + } + Map stringLongMap = counts.get(correctAnswer); if (stringLongMap != null) { Long aLong = stringLongMap.get(classified); @@ -105,6 +115,7 @@ public class ConfusionMatrixGenerator { stringLongMap.put(classified, 1l); counts.put(correctAnswer, stringLongMap); } + } } } catch (TimeoutException timeoutException) { @@ -131,6 +142,7 @@ public class ConfusionMatrixGenerator { private final Map> linearizedMatrix; private final double avgClassificationTime; private final int numberOfEvaluatedDocs; + private double accuracy = -1d; private ConfusionMatrix(Map> linearizedMatrix, double avgClassificationTime, int numberOfEvaluatedDocs) { this.linearizedMatrix = linearizedMatrix; @@ -140,12 +152,42 @@ public class ConfusionMatrixGenerator { /** * get the linearized confusion matrix as a {@link Map} - * @return a {@link Map} whose keys are the correct answers and whose values are the actual answers' counts + * @return a {@link Map} whose keys are the correct classification answers and whose values are the actual answers' + * counts */ public Map> getLinearizedMatrix() { return Collections.unmodifiableMap(linearizedMatrix); } + /** + * Calculate accuracy on this confusion matrix using the formula: + * {@literal accuracy = correctly-classified / (correctly-classified + wrongly-classified)} + * + * @return the accuracy + */ + public double getAccuracy() { + if (this.accuracy == -1) { + double cc = 0d; + double wc = 0d; + for (Map.Entry> entry : linearizedMatrix.entrySet()) { + String correctAnswer = entry.getKey(); + for (Map.Entry classifiedAnswers : entry.getValue().entrySet()) { + Long value = classifiedAnswers.getValue(); + if (value != null) { + if (correctAnswer.equals(classifiedAnswers.getKey())) { + cc += value; + } else { + wc += value; + } + } + } + + } + this.accuracy = cc / (cc + wc); + } + return this.accuracy; + } + @Override public String toString() { return "ConfusionMatrix{" + diff --git a/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java b/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java index 36a93324a9a..becff756fb8 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java @@ -124,12 +124,18 @@ public class DatasetSplitter { } b++; } - } catch (Exception e) { - throw new IOException(e); - } finally { + // commit testWriter.commit(); cvWriter.commit(); trainingWriter.commit(); + + // merge + testWriter.forceMerge(3); + cvWriter.forceMerge(3); + trainingWriter.forceMerge(3); + } catch (Exception e) { + throw new IOException(e); + } finally { // close IWs testWriter.close(); cvWriter.close(); diff --git a/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java b/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java index 014d5fa3516..4f81b2d2aa9 100644 --- a/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java +++ b/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java @@ -79,7 +79,7 @@ public class BooleanPerceptronClassifierTest extends ClassificationTestBase avgClassificationTime); + double accuracy = confusionMatrix.getAccuracy(); + assertTrue(accuracy > 0d); } finally { leafReader.close(); } diff --git a/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java b/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java index e1a6cc188a0..6aaff59a221 100644 --- a/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java +++ b/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java @@ -108,6 +108,8 @@ public class CachingNaiveBayesClassifierTest extends ClassificationTestBase avgClassificationTime); + double accuracy = confusionMatrix.getAccuracy(); + assertTrue(accuracy > 0d); } finally { leafReader.close(); } diff --git a/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java b/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java index f3252bfb7a6..e1d1b5a905d 100644 --- a/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java +++ b/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java @@ -124,7 +124,7 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase avgClassificationTime); + double accuracy = confusionMatrix.getAccuracy(); + assertTrue(accuracy > 0d); } finally { leafReader.close(); } diff --git a/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java b/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java index b8ac8510323..667c427a30a 100644 --- a/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java +++ b/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java @@ -109,6 +109,8 @@ public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase avgClassificationTime); + double accuracy = confusionMatrix.getAccuracy(); + assertTrue(accuracy > 0d); } finally { leafReader.close(); } diff --git a/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java b/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java index fbf589976fc..23d08fbba86 100644 --- a/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java +++ b/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java @@ -84,6 +84,7 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase assertNotNull(confusionMatrix.getLinearizedMatrix()); assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs()); assertTrue(confusionMatrix.getAvgClassificationTime() > 0d); + assertTrue(confusionMatrix.getAccuracy() > 0d); } finally { if (reader != null) { reader.close(); @@ -103,6 +104,7 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase assertNotNull(confusionMatrix.getLinearizedMatrix()); assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs()); assertTrue(confusionMatrix.getAvgClassificationTime() > 0d); + assertTrue(confusionMatrix.getAccuracy() > 0d); } finally { if (reader != null) { reader.close(); @@ -122,6 +124,7 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase assertNotNull(confusionMatrix.getLinearizedMatrix()); assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs()); assertTrue(confusionMatrix.getAvgClassificationTime() > 0d); + assertTrue(confusionMatrix.getAccuracy() > 0d); } finally { if (reader != null) { reader.close(); @@ -141,6 +144,7 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase assertNotNull(confusionMatrix.getLinearizedMatrix()); assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs()); assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d); + assertTrue(confusionMatrix.getAccuracy() > 0d); } finally { if (reader != null) { reader.close();