mirror of https://github.com/apache/lucene.git
LUCENE-5699 - normalized score for boolean perceptron classifier
git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1638715 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
a323091dac
commit
4641ee4c5d
|
@ -16,6 +16,12 @@
|
|||
*/
|
||||
package org.apache.lucene.classification;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.SortedMap;
|
||||
import java.util.concurrent.ConcurrentSkipListMap;
|
||||
|
||||
import org.apache.lucene.analysis.Analyzer;
|
||||
import org.apache.lucene.analysis.TokenStream;
|
||||
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
|
||||
|
@ -40,12 +46,6 @@ import org.apache.lucene.util.fst.FST;
|
|||
import org.apache.lucene.util.fst.PositiveIntOutputs;
|
||||
import org.apache.lucene.util.fst.Util;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.SortedMap;
|
||||
import java.util.TreeMap;
|
||||
|
||||
/**
|
||||
* A perceptron (see <code>http://en.wikipedia.org/wiki/Perceptron</code>) based
|
||||
* <code>Boolean</code> {@link org.apache.lucene.classification.Classifier}. The
|
||||
|
@ -53,7 +53,7 @@ import java.util.TreeMap;
|
|||
* {@link org.apache.lucene.index.TermsEnum#totalTermFreq} both on a per field
|
||||
* and a per document basis and then a corresponding
|
||||
* {@link org.apache.lucene.util.fst.FST} is used for class assignment.
|
||||
*
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
||||
|
@ -67,9 +67,8 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
|||
|
||||
/**
|
||||
* Create a {@link BooleanPerceptronClassifier}
|
||||
*
|
||||
* @param threshold
|
||||
* the binary threshold for perceptron output evaluation
|
||||
*
|
||||
* @param threshold the binary threshold for perceptron output evaluation
|
||||
*/
|
||||
public BooleanPerceptronClassifier(Double threshold, Integer batchSize) {
|
||||
this.threshold = threshold;
|
||||
|
@ -98,7 +97,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
|||
Long output = 0l;
|
||||
try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, text)) {
|
||||
CharTermAttribute charTermAttribute = tokenStream
|
||||
.addAttribute(CharTermAttribute.class);
|
||||
.addAttribute(CharTermAttribute.class);
|
||||
tokenStream.reset();
|
||||
while (tokenStream.incrementToken()) {
|
||||
String s = charTermAttribute.toString();
|
||||
|
@ -110,7 +109,8 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
|||
tokenStream.end();
|
||||
}
|
||||
|
||||
return new ClassificationResult<>(output >= threshold, output.doubleValue());
|
||||
double score = 1 - Math.exp(-1 * Math.abs(threshold - output.doubleValue()) / threshold);
|
||||
return new ClassificationResult<>(output >= threshold, score);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -127,7 +127,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
|||
*/
|
||||
@Override
|
||||
public void train(LeafReader leafReader, String textFieldName,
|
||||
String classFieldName, Analyzer analyzer, Query query) throws IOException {
|
||||
String classFieldName, Analyzer analyzer, Query query) throws IOException {
|
||||
this.textTerms = MultiFields.getTerms(leafReader, textFieldName);
|
||||
|
||||
if (textTerms == null) {
|
||||
|
@ -150,7 +150,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
|||
}
|
||||
|
||||
// TODO : remove this map as soon as we have a writable FST
|
||||
SortedMap<String,Double> weights = new TreeMap<>();
|
||||
SortedMap<String, Double> weights = new ConcurrentSkipListMap<>();
|
||||
|
||||
TermsEnum reuse = textTerms.iterator(null);
|
||||
BytesRef textTerm;
|
||||
|
@ -177,10 +177,10 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
|||
ClassificationResult<Boolean> classificationResult = assignClass(doc
|
||||
.getField(textFieldName).stringValue());
|
||||
Boolean assignedClass = classificationResult.getAssignedClass();
|
||||
|
||||
|
||||
// get the expected result
|
||||
StorableField field = doc.getField(classFieldName);
|
||||
|
||||
|
||||
Boolean correctClass = Boolean.valueOf(field.stringValue());
|
||||
long modifier = correctClass.compareTo(assignedClass);
|
||||
if (modifier != 0) {
|
||||
|
@ -198,8 +198,8 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
|||
}
|
||||
|
||||
private TermsEnum updateWeights(LeafReader leafReader, TermsEnum reuse,
|
||||
int docId, Boolean assignedClass, SortedMap<String,Double> weights,
|
||||
double modifier, boolean updateFST) throws IOException {
|
||||
int docId, Boolean assignedClass, SortedMap<String, Double> weights,
|
||||
double modifier, boolean updateFST) throws IOException {
|
||||
TermsEnum cte = textTerms.iterator(reuse);
|
||||
|
||||
// get the doc term vectors
|
||||
|
@ -231,12 +231,12 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
|||
return reuse;
|
||||
}
|
||||
|
||||
private void updateFST(SortedMap<String,Double> weights) throws IOException {
|
||||
private void updateFST(SortedMap<String, Double> weights) throws IOException {
|
||||
PositiveIntOutputs outputs = PositiveIntOutputs.getSingleton();
|
||||
Builder<Long> fstBuilder = new Builder<>(FST.INPUT_TYPE.BYTE1, outputs);
|
||||
BytesRefBuilder scratchBytes = new BytesRefBuilder();
|
||||
IntsRefBuilder scratchInts = new IntsRefBuilder();
|
||||
for (Map.Entry<String,Double> entry : weights.entrySet()) {
|
||||
for (Map.Entry<String, Double> entry : weights.entrySet()) {
|
||||
scratchBytes.copyChars(entry.getKey());
|
||||
fstBuilder.add(Util.toIntsRef(scratchBytes.get(), scratchInts), entry
|
||||
.getValue().longValue());
|
||||
|
|
|
@ -91,7 +91,8 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
|
|||
ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
|
||||
assertNotNull(classificationResult.getAssignedClass());
|
||||
assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
|
||||
assertTrue("got a not positive score " + classificationResult.getScore(), classificationResult.getScore() > 0);
|
||||
double score = classificationResult.getScore();
|
||||
assertTrue("score should be between 0 and 1, got:" + score, score <= 1 && score >= 0);
|
||||
} finally {
|
||||
if (leafReader != null)
|
||||
leafReader.close();
|
||||
|
@ -110,11 +111,12 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
|
|||
ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
|
||||
assertNotNull(classificationResult.getAssignedClass());
|
||||
assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
|
||||
assertTrue("got a not positive score " + classificationResult.getScore(), classificationResult.getScore() > 0);
|
||||
double score = classificationResult.getScore();
|
||||
assertTrue("score should be between 0 and 1, got: " + score, score <= 1 && score >= 0);
|
||||
updateSampleIndex(analyzer);
|
||||
ClassificationResult<T> secondClassificationResult = classifier.assignClass(inputDoc);
|
||||
assertEquals(classificationResult.getAssignedClass(), secondClassificationResult.getAssignedClass());
|
||||
assertEquals(Double.valueOf(classificationResult.getScore()), Double.valueOf(secondClassificationResult.getScore()));
|
||||
assertEquals(Double.valueOf(score), Double.valueOf(secondClassificationResult.getScore()));
|
||||
|
||||
} finally {
|
||||
if (leafReader != null)
|
||||
|
|
Loading…
Reference in New Issue