diff --git a/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java
index 497cc884104..7b35a069f2c 100644
--- a/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java
+++ b/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java
@@ -16,6 +16,12 @@
*/
package org.apache.lucene.classification;
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.SortedMap;
+import java.util.concurrent.ConcurrentSkipListMap;
+
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
@@ -40,12 +46,6 @@ import org.apache.lucene.util.fst.FST;
import org.apache.lucene.util.fst.PositiveIntOutputs;
import org.apache.lucene.util.fst.Util;
-import java.io.IOException;
-import java.util.List;
-import java.util.Map;
-import java.util.SortedMap;
-import java.util.TreeMap;
-
/**
* A perceptron (see http://en.wikipedia.org/wiki/Perceptron
) based
* Boolean
{@link org.apache.lucene.classification.Classifier}. The
@@ -53,7 +53,7 @@ import java.util.TreeMap;
* {@link org.apache.lucene.index.TermsEnum#totalTermFreq} both on a per field
* and a per document basis and then a corresponding
* {@link org.apache.lucene.util.fst.FST} is used for class assignment.
- *
+ *
* @lucene.experimental
*/
public class BooleanPerceptronClassifier implements Classifier {
@@ -67,9 +67,8 @@ public class BooleanPerceptronClassifier implements Classifier {
/**
* Create a {@link BooleanPerceptronClassifier}
- *
- * @param threshold
- * the binary threshold for perceptron output evaluation
+ *
+ * @param threshold the binary threshold for perceptron output evaluation
*/
public BooleanPerceptronClassifier(Double threshold, Integer batchSize) {
this.threshold = threshold;
@@ -98,7 +97,7 @@ public class BooleanPerceptronClassifier implements Classifier {
Long output = 0l;
try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, text)) {
CharTermAttribute charTermAttribute = tokenStream
- .addAttribute(CharTermAttribute.class);
+ .addAttribute(CharTermAttribute.class);
tokenStream.reset();
while (tokenStream.incrementToken()) {
String s = charTermAttribute.toString();
@@ -110,7 +109,8 @@ public class BooleanPerceptronClassifier implements Classifier {
tokenStream.end();
}
- return new ClassificationResult<>(output >= threshold, output.doubleValue());
+ double score = 1 - Math.exp(-1 * Math.abs(threshold - output.doubleValue()) / threshold);
+ return new ClassificationResult<>(output >= threshold, score);
}
/**
@@ -127,7 +127,7 @@ public class BooleanPerceptronClassifier implements Classifier {
*/
@Override
public void train(LeafReader leafReader, String textFieldName,
- String classFieldName, Analyzer analyzer, Query query) throws IOException {
+ String classFieldName, Analyzer analyzer, Query query) throws IOException {
this.textTerms = MultiFields.getTerms(leafReader, textFieldName);
if (textTerms == null) {
@@ -150,7 +150,7 @@ public class BooleanPerceptronClassifier implements Classifier {
}
// TODO : remove this map as soon as we have a writable FST
- SortedMap weights = new TreeMap<>();
+ SortedMap weights = new ConcurrentSkipListMap<>();
TermsEnum reuse = textTerms.iterator(null);
BytesRef textTerm;
@@ -177,10 +177,10 @@ public class BooleanPerceptronClassifier implements Classifier {
ClassificationResult classificationResult = assignClass(doc
.getField(textFieldName).stringValue());
Boolean assignedClass = classificationResult.getAssignedClass();
-
+
// get the expected result
StorableField field = doc.getField(classFieldName);
-
+
Boolean correctClass = Boolean.valueOf(field.stringValue());
long modifier = correctClass.compareTo(assignedClass);
if (modifier != 0) {
@@ -198,8 +198,8 @@ public class BooleanPerceptronClassifier implements Classifier {
}
private TermsEnum updateWeights(LeafReader leafReader, TermsEnum reuse,
- int docId, Boolean assignedClass, SortedMap weights,
- double modifier, boolean updateFST) throws IOException {
+ int docId, Boolean assignedClass, SortedMap weights,
+ double modifier, boolean updateFST) throws IOException {
TermsEnum cte = textTerms.iterator(reuse);
// get the doc term vectors
@@ -231,12 +231,12 @@ public class BooleanPerceptronClassifier implements Classifier {
return reuse;
}
- private void updateFST(SortedMap weights) throws IOException {
+ private void updateFST(SortedMap weights) throws IOException {
PositiveIntOutputs outputs = PositiveIntOutputs.getSingleton();
Builder fstBuilder = new Builder<>(FST.INPUT_TYPE.BYTE1, outputs);
BytesRefBuilder scratchBytes = new BytesRefBuilder();
IntsRefBuilder scratchInts = new IntsRefBuilder();
- for (Map.Entry entry : weights.entrySet()) {
+ for (Map.Entry entry : weights.entrySet()) {
scratchBytes.copyChars(entry.getKey());
fstBuilder.add(Util.toIntsRef(scratchBytes.get(), scratchInts), entry
.getValue().longValue());
diff --git a/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java b/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
index 179cf4ebc0a..3ff1f80d566 100644
--- a/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
+++ b/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
@@ -91,7 +91,8 @@ public abstract class ClassificationTestBase extends LuceneTestCase {
ClassificationResult classificationResult = classifier.assignClass(inputDoc);
assertNotNull(classificationResult.getAssignedClass());
assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
- assertTrue("got a not positive score " + classificationResult.getScore(), classificationResult.getScore() > 0);
+ double score = classificationResult.getScore();
+ assertTrue("score should be between 0 and 1, got:" + score, score <= 1 && score >= 0);
} finally {
if (leafReader != null)
leafReader.close();
@@ -110,11 +111,12 @@ public abstract class ClassificationTestBase extends LuceneTestCase {
ClassificationResult classificationResult = classifier.assignClass(inputDoc);
assertNotNull(classificationResult.getAssignedClass());
assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
- assertTrue("got a not positive score " + classificationResult.getScore(), classificationResult.getScore() > 0);
+ double score = classificationResult.getScore();
+ assertTrue("score should be between 0 and 1, got: " + score, score <= 1 && score >= 0);
updateSampleIndex(analyzer);
ClassificationResult secondClassificationResult = classifier.assignClass(inputDoc);
assertEquals(classificationResult.getAssignedClass(), secondClassificationResult.getAssignedClass());
- assertEquals(Double.valueOf(classificationResult.getScore()), Double.valueOf(secondClassificationResult.getScore()));
+ assertEquals(Double.valueOf(score), Double.valueOf(secondClassificationResult.getScore()));
} finally {
if (leafReader != null)