LUCENE-5699 - normalized score for boolean perceptron classifier

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1638715 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Tommaso Teofili 2014-11-12 08:38:06 +00:00
parent a323091dac
commit 4641ee4c5d
2 changed files with 25 additions and 23 deletions

View File

@ -16,6 +16,12 @@
*/
package org.apache.lucene.classification;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.SortedMap;
import java.util.concurrent.ConcurrentSkipListMap;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
@ -40,12 +46,6 @@ import org.apache.lucene.util.fst.FST;
import org.apache.lucene.util.fst.PositiveIntOutputs;
import org.apache.lucene.util.fst.Util;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;
/**
* A perceptron (see <code>http://en.wikipedia.org/wiki/Perceptron</code>) based
* <code>Boolean</code> {@link org.apache.lucene.classification.Classifier}. The
@ -53,7 +53,7 @@ import java.util.TreeMap;
* {@link org.apache.lucene.index.TermsEnum#totalTermFreq} both on a per field
* and a per document basis and then a corresponding
* {@link org.apache.lucene.util.fst.FST} is used for class assignment.
*
*
* @lucene.experimental
*/
public class BooleanPerceptronClassifier implements Classifier<Boolean> {
@ -67,9 +67,8 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
/**
* Create a {@link BooleanPerceptronClassifier}
*
* @param threshold
* the binary threshold for perceptron output evaluation
*
* @param threshold the binary threshold for perceptron output evaluation
*/
public BooleanPerceptronClassifier(Double threshold, Integer batchSize) {
this.threshold = threshold;
@ -98,7 +97,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
Long output = 0l;
try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, text)) {
CharTermAttribute charTermAttribute = tokenStream
.addAttribute(CharTermAttribute.class);
.addAttribute(CharTermAttribute.class);
tokenStream.reset();
while (tokenStream.incrementToken()) {
String s = charTermAttribute.toString();
@ -110,7 +109,8 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
tokenStream.end();
}
return new ClassificationResult<>(output >= threshold, output.doubleValue());
double score = 1 - Math.exp(-1 * Math.abs(threshold - output.doubleValue()) / threshold);
return new ClassificationResult<>(output >= threshold, score);
}
/**
@ -127,7 +127,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
*/
@Override
public void train(LeafReader leafReader, String textFieldName,
String classFieldName, Analyzer analyzer, Query query) throws IOException {
String classFieldName, Analyzer analyzer, Query query) throws IOException {
this.textTerms = MultiFields.getTerms(leafReader, textFieldName);
if (textTerms == null) {
@ -150,7 +150,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
}
// TODO : remove this map as soon as we have a writable FST
SortedMap<String,Double> weights = new TreeMap<>();
SortedMap<String, Double> weights = new ConcurrentSkipListMap<>();
TermsEnum reuse = textTerms.iterator(null);
BytesRef textTerm;
@ -177,10 +177,10 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
ClassificationResult<Boolean> classificationResult = assignClass(doc
.getField(textFieldName).stringValue());
Boolean assignedClass = classificationResult.getAssignedClass();
// get the expected result
StorableField field = doc.getField(classFieldName);
Boolean correctClass = Boolean.valueOf(field.stringValue());
long modifier = correctClass.compareTo(assignedClass);
if (modifier != 0) {
@ -198,8 +198,8 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
}
private TermsEnum updateWeights(LeafReader leafReader, TermsEnum reuse,
int docId, Boolean assignedClass, SortedMap<String,Double> weights,
double modifier, boolean updateFST) throws IOException {
int docId, Boolean assignedClass, SortedMap<String, Double> weights,
double modifier, boolean updateFST) throws IOException {
TermsEnum cte = textTerms.iterator(reuse);
// get the doc term vectors
@ -231,12 +231,12 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
return reuse;
}
private void updateFST(SortedMap<String,Double> weights) throws IOException {
private void updateFST(SortedMap<String, Double> weights) throws IOException {
PositiveIntOutputs outputs = PositiveIntOutputs.getSingleton();
Builder<Long> fstBuilder = new Builder<>(FST.INPUT_TYPE.BYTE1, outputs);
BytesRefBuilder scratchBytes = new BytesRefBuilder();
IntsRefBuilder scratchInts = new IntsRefBuilder();
for (Map.Entry<String,Double> entry : weights.entrySet()) {
for (Map.Entry<String, Double> entry : weights.entrySet()) {
scratchBytes.copyChars(entry.getKey());
fstBuilder.add(Util.toIntsRef(scratchBytes.get(), scratchInts), entry
.getValue().longValue());

View File

@ -91,7 +91,8 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
assertNotNull(classificationResult.getAssignedClass());
assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
assertTrue("got a not positive score " + classificationResult.getScore(), classificationResult.getScore() > 0);
double score = classificationResult.getScore();
assertTrue("score should be between 0 and 1, got:" + score, score <= 1 && score >= 0);
} finally {
if (leafReader != null)
leafReader.close();
@ -110,11 +111,12 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
assertNotNull(classificationResult.getAssignedClass());
assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
assertTrue("got a not positive score " + classificationResult.getScore(), classificationResult.getScore() > 0);
double score = classificationResult.getScore();
assertTrue("score should be between 0 and 1, got: " + score, score <= 1 && score >= 0);
updateSampleIndex(analyzer);
ClassificationResult<T> secondClassificationResult = classifier.assignClass(inputDoc);
assertEquals(classificationResult.getAssignedClass(), secondClassificationResult.getAssignedClass());
assertEquals(Double.valueOf(classificationResult.getScore()), Double.valueOf(secondClassificationResult.getScore()));
assertEquals(Double.valueOf(score), Double.valueOf(secondClassificationResult.getScore()));
} finally {
if (leafReader != null)