LUCENE-6853 - renamed threshold to bias; when unset, bias is initialized to the average total term frequency per document

This commit is contained in:
Tommaso Teofili 2017-04-06 19:05:52 +02:00
parent c05ab96dc4
commit cbad533d7a
2 changed files with 22 additions and 18 deletions

View File

@ -58,7 +58,7 @@ import org.apache.lucene.util.fst.Util;
*/ */
public class BooleanPerceptronClassifier implements Classifier<Boolean> { public class BooleanPerceptronClassifier implements Classifier<Boolean> {
private final Double threshold; private final Double bias;
private final Terms textTerms; private final Terms textTerms;
private final Analyzer analyzer; private final Analyzer analyzer;
private final String textFieldName; private final String textFieldName;
@ -72,14 +72,14 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
* @param query a {@link Query} to eventually filter the docs used for training the classifier, or {@code null} * @param query a {@link Query} to eventually filter the docs used for training the classifier, or {@code null}
* if all the indexed docs should be used * if all the indexed docs should be used
* @param batchSize the size of the batch of docs to use for updating the perceptron weights * @param batchSize the size of the batch of docs to use for updating the perceptron weights
* @param threshold the threshold used for class separation * @param bias the bias used for class separation
* @param classFieldName the name of the field used as the output for the classifier * @param classFieldName the name of the field used as the output for the classifier
* @param textFieldName the name of the field used as input for the classifier * @param textFieldName the name of the field used as input for the classifier
* @throws IOException if the building of the underlying {@link FST} fails and / or {@link TermsEnum} for the text field * @throws IOException if the building of the underlying {@link FST} fails and / or {@link TermsEnum} for the text field
* cannot be found * cannot be found
*/ */
public BooleanPerceptronClassifier(IndexReader indexReader, Analyzer analyzer, Query query, Integer batchSize, public BooleanPerceptronClassifier(IndexReader indexReader, Analyzer analyzer, Query query, Integer batchSize,
Double threshold, String classFieldName, String textFieldName) throws IOException { Double bias, String classFieldName, String textFieldName) throws IOException {
this.textTerms = MultiFields.getTerms(indexReader, textFieldName); this.textTerms = MultiFields.getTerms(indexReader, textFieldName);
if (textTerms == null) { if (textTerms == null) {
@ -89,18 +89,18 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
this.analyzer = analyzer; this.analyzer = analyzer;
this.textFieldName = textFieldName; this.textFieldName = textFieldName;
if (threshold == null || threshold == 0d) { if (bias == null || bias == 0d) {
// automatic assign a threshold // automatic assign the bias to be the average total term freq
long sumDocFreq = indexReader.getSumDocFreq(textFieldName); double t = (double) indexReader.getSumTotalTermFreq(textFieldName) / (double) indexReader.getDocCount(textFieldName);
if (sumDocFreq != -1) { if (t != -1) {
this.threshold = (double) sumDocFreq / 2d; this.bias = t;
} else { } else {
throw new IOException( throw new IOException(
"threshold cannot be assigned since term vectors for field " "bias cannot be assigned since term vectors for field "
+ textFieldName + " do not exist"); + textFieldName + " do not exist");
} }
} else { } else {
this.threshold = threshold; this.bias = bias;
} }
// TODO : remove this map as soon as we have a writable FST // TODO : remove this map as soon as we have a writable FST
@ -173,7 +173,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
// update weights // update weights
Long previousValue = Util.get(fst, term); Long previousValue = Util.get(fst, term);
String termString = term.utf8ToString(); String termString = term.utf8ToString();
weights.put(termString, previousValue == null ? 0 : previousValue + modifier * termFreqLocal); weights.put(termString, previousValue == null ? 0 : Math.max(0, previousValue + modifier * termFreqLocal));
} }
} }
if (updateFST) { if (updateFST) {
@ -216,8 +216,8 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
tokenStream.end(); tokenStream.end();
} }
double score = 1 - Math.exp(-1 * Math.abs(threshold - output.doubleValue()) / threshold); double score = 1 - Math.exp(-1 * Math.abs(bias - output.doubleValue()) / bias);
return new ClassificationResult<>(output >= threshold, score); return new ClassificationResult<>(output >= bias, score);
} }
/** /**

View File

@ -34,7 +34,9 @@ public class BooleanPerceptronClassifierTest extends ClassificationTestBase<Bool
try { try {
MockAnalyzer analyzer = new MockAnalyzer(random()); MockAnalyzer analyzer = new MockAnalyzer(random());
leafReader = getSampleIndex(analyzer); leafReader = getSampleIndex(analyzer);
checkCorrectClassification(new BooleanPerceptronClassifier(leafReader, analyzer, null, 1, null, booleanFieldName, textFieldName), TECHNOLOGY_INPUT, false); BooleanPerceptronClassifier classifier = new BooleanPerceptronClassifier(leafReader, analyzer, null, 1, null, booleanFieldName, textFieldName);
checkCorrectClassification(classifier, TECHNOLOGY_INPUT, false);
checkCorrectClassification(classifier, POLITICS_INPUT, true);
} finally { } finally {
if (leafReader != null) { if (leafReader != null) {
leafReader.close(); leafReader.close();
@ -60,12 +62,14 @@ public class BooleanPerceptronClassifierTest extends ClassificationTestBase<Bool
@Test @Test
public void testBasicUsageWithQuery() throws Exception { public void testBasicUsageWithQuery() throws Exception {
TermQuery query = new TermQuery(new Term(textFieldName, "it")); TermQuery query = new TermQuery(new Term(textFieldName, "of"));
LeafReader leafReader = null; LeafReader leafReader = null;
try { try {
MockAnalyzer analyzer = new MockAnalyzer(random()); MockAnalyzer analyzer = new MockAnalyzer(random());
leafReader = getSampleIndex(analyzer); leafReader = getSampleIndex(analyzer);
checkCorrectClassification(new BooleanPerceptronClassifier(leafReader, analyzer, query, 1, null, booleanFieldName, textFieldName), TECHNOLOGY_INPUT, false); BooleanPerceptronClassifier classifier = new BooleanPerceptronClassifier(leafReader, analyzer, query, 1, null, booleanFieldName, textFieldName);
checkCorrectClassification(classifier, TECHNOLOGY_INPUT, false);
checkCorrectClassification(classifier, POLITICS_INPUT, true);
} finally { } finally {
if (leafReader != null) { if (leafReader != null) {
leafReader.close(); leafReader.close();
@ -94,8 +98,8 @@ public class BooleanPerceptronClassifierTest extends ClassificationTestBase<Bool
double avgClassificationTime = confusionMatrix.getAvgClassificationTime(); double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
assertTrue(5000 > avgClassificationTime); assertTrue(5000 > avgClassificationTime);
// accuracy check disabled until LUCENE-6853 is fixed // accuracy check disabled until LUCENE-6853 is fixed
// double accuracy = confusionMatrix.getAccuracy(); double accuracy = confusionMatrix.getAccuracy();
// assertTrue(accuracy > 0d); assertTrue(accuracy > 0d);
} finally { } finally {
leafReader.close(); leafReader.close();
} }