diff --git a/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java index 497cc884104..7b35a069f2c 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java @@ -16,6 +16,12 @@ */ package org.apache.lucene.classification; +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.SortedMap; +import java.util.concurrent.ConcurrentSkipListMap; + import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.analysis.TokenStream; import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; @@ -40,12 +46,6 @@ import org.apache.lucene.util.fst.FST; import org.apache.lucene.util.fst.PositiveIntOutputs; import org.apache.lucene.util.fst.Util; -import java.io.IOException; -import java.util.List; -import java.util.Map; -import java.util.SortedMap; -import java.util.TreeMap; - /** * A perceptron (see http://en.wikipedia.org/wiki/Perceptron) based * Boolean {@link org.apache.lucene.classification.Classifier}. The @@ -53,7 +53,7 @@ import java.util.TreeMap; * {@link org.apache.lucene.index.TermsEnum#totalTermFreq} both on a per field * and a per document basis and then a corresponding * {@link org.apache.lucene.util.fst.FST} is used for class assignment. - * + * * @lucene.experimental */ public class BooleanPerceptronClassifier implements Classifier { @@ -67,9 +67,8 @@ public class BooleanPerceptronClassifier implements Classifier { /** * Create a {@link BooleanPerceptronClassifier} - * - * @param threshold - * the binary threshold for perceptron output evaluation + * + * @param threshold the binary threshold for perceptron output evaluation */ public BooleanPerceptronClassifier(Double threshold, Integer batchSize) { this.threshold = threshold; @@ -98,7 +97,7 @@ public class BooleanPerceptronClassifier implements Classifier { Long output = 0l; try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, text)) { CharTermAttribute charTermAttribute = tokenStream - .addAttribute(CharTermAttribute.class); + .addAttribute(CharTermAttribute.class); tokenStream.reset(); while (tokenStream.incrementToken()) { String s = charTermAttribute.toString(); @@ -110,7 +109,8 @@ public class BooleanPerceptronClassifier implements Classifier { tokenStream.end(); } - return new ClassificationResult<>(output >= threshold, output.doubleValue()); + double score = 1 - Math.exp(-1 * Math.abs(threshold - output.doubleValue()) / threshold); + return new ClassificationResult<>(output >= threshold, score); } /** @@ -127,7 +127,7 @@ public class BooleanPerceptronClassifier implements Classifier { */ @Override public void train(LeafReader leafReader, String textFieldName, - String classFieldName, Analyzer analyzer, Query query) throws IOException { + String classFieldName, Analyzer analyzer, Query query) throws IOException { this.textTerms = MultiFields.getTerms(leafReader, textFieldName); if (textTerms == null) { @@ -150,7 +150,7 @@ public class BooleanPerceptronClassifier implements Classifier { } // TODO : remove this map as soon as we have a writable FST - SortedMap weights = new TreeMap<>(); + SortedMap weights = new ConcurrentSkipListMap<>(); TermsEnum reuse = textTerms.iterator(null); BytesRef textTerm; @@ -177,10 +177,10 @@ public class BooleanPerceptronClassifier implements Classifier { ClassificationResult classificationResult = assignClass(doc .getField(textFieldName).stringValue()); Boolean assignedClass = classificationResult.getAssignedClass(); - + // get the expected result StorableField field = doc.getField(classFieldName); - + Boolean correctClass = Boolean.valueOf(field.stringValue()); long modifier = correctClass.compareTo(assignedClass); if (modifier != 0) { @@ -198,8 +198,8 @@ public class BooleanPerceptronClassifier implements Classifier { } private TermsEnum updateWeights(LeafReader leafReader, TermsEnum reuse, - int docId, Boolean assignedClass, SortedMap weights, - double modifier, boolean updateFST) throws IOException { + int docId, Boolean assignedClass, SortedMap weights, + double modifier, boolean updateFST) throws IOException { TermsEnum cte = textTerms.iterator(reuse); // get the doc term vectors @@ -231,12 +231,12 @@ public class BooleanPerceptronClassifier implements Classifier { return reuse; } - private void updateFST(SortedMap weights) throws IOException { + private void updateFST(SortedMap weights) throws IOException { PositiveIntOutputs outputs = PositiveIntOutputs.getSingleton(); Builder fstBuilder = new Builder<>(FST.INPUT_TYPE.BYTE1, outputs); BytesRefBuilder scratchBytes = new BytesRefBuilder(); IntsRefBuilder scratchInts = new IntsRefBuilder(); - for (Map.Entry entry : weights.entrySet()) { + for (Map.Entry entry : weights.entrySet()) { scratchBytes.copyChars(entry.getKey()); fstBuilder.add(Util.toIntsRef(scratchBytes.get(), scratchInts), entry .getValue().longValue()); diff --git a/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java b/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java index 179cf4ebc0a..3ff1f80d566 100644 --- a/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java +++ b/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java @@ -91,7 +91,8 @@ public abstract class ClassificationTestBase extends LuceneTestCase { ClassificationResult classificationResult = classifier.assignClass(inputDoc); assertNotNull(classificationResult.getAssignedClass()); assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass()); - assertTrue("got a not positive score " + classificationResult.getScore(), classificationResult.getScore() > 0); + double score = classificationResult.getScore(); + assertTrue("score should be between 0 and 1, got:" + score, score <= 1 && score >= 0); } finally { if (leafReader != null) leafReader.close(); @@ -110,11 +111,12 @@ public abstract class ClassificationTestBase extends LuceneTestCase { ClassificationResult classificationResult = classifier.assignClass(inputDoc); assertNotNull(classificationResult.getAssignedClass()); assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass()); - assertTrue("got a not positive score " + classificationResult.getScore(), classificationResult.getScore() > 0); + double score = classificationResult.getScore(); + assertTrue("score should be between 0 and 1, got: " + score, score <= 1 && score >= 0); updateSampleIndex(analyzer); ClassificationResult secondClassificationResult = classifier.assignClass(inputDoc); assertEquals(classificationResult.getAssignedClass(), secondClassificationResult.getAssignedClass()); - assertEquals(Double.valueOf(classificationResult.getScore()), Double.valueOf(secondClassificationResult.getScore())); + assertEquals(Double.valueOf(score), Double.valueOf(secondClassificationResult.getScore())); } finally { if (leafReader != null)