From cbad533d7a44a5fd41f85756c791f3d7439861a2 Mon Sep 17 00:00:00 2001 From: Tommaso Teofili Date: Thu, 6 Apr 2017 19:05:52 +0200 Subject: [PATCH] LUCENE-6853 - renamed threshold to bias, initialize to avg tf --- .../BooleanPerceptronClassifier.java | 26 +++++++++---------- .../BooleanPerceptronClassifierTest.java | 14 ++++++---- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java index 781a14ff6ee..928c0366770 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java @@ -58,7 +58,7 @@ import org.apache.lucene.util.fst.Util; */ public class BooleanPerceptronClassifier implements Classifier { - private final Double threshold; + private final Double bias; private final Terms textTerms; private final Analyzer analyzer; private final String textFieldName; @@ -72,14 +72,14 @@ public class BooleanPerceptronClassifier implements Classifier { * @param query a {@link Query} to eventually filter the docs used for training the classifier, or {@code null} * if all the indexed docs should be used * @param batchSize the size of the batch of docs to use for updating the perceptron weights - * @param threshold the threshold used for class separation + * @param bias the bias used for class separation * @param classFieldName the name of the field used as the output for the classifier * @param textFieldName the name of the field used as input for the classifier * @throws IOException if the building of the underlying {@link FST} fails and / or {@link TermsEnum} for the text field * cannot be found */ public BooleanPerceptronClassifier(IndexReader indexReader, Analyzer analyzer, Query query, Integer batchSize, - Double threshold, String classFieldName, String textFieldName) throws IOException { + Double bias, String classFieldName, String textFieldName) throws IOException { this.textTerms = MultiFields.getTerms(indexReader, textFieldName); if (textTerms == null) { @@ -89,18 +89,18 @@ public class BooleanPerceptronClassifier implements Classifier { this.analyzer = analyzer; this.textFieldName = textFieldName; - if (threshold == null || threshold == 0d) { - // automatic assign a threshold - long sumDocFreq = indexReader.getSumDocFreq(textFieldName); - if (sumDocFreq != -1) { - this.threshold = (double) sumDocFreq / 2d; + if (bias == null || bias == 0d) { + // automatic assign the bias to be the average total term freq + double t = (double) indexReader.getSumTotalTermFreq(textFieldName) / (double) indexReader.getDocCount(textFieldName); + if (t != -1) { + this.bias = t; } else { throw new IOException( - "threshold cannot be assigned since term vectors for field " + "bias cannot be assigned since term vectors for field " + textFieldName + " do not exist"); } } else { - this.threshold = threshold; + this.bias = bias; } // TODO : remove this map as soon as we have a writable FST @@ -173,7 +173,7 @@ public class BooleanPerceptronClassifier implements Classifier { // update weights Long previousValue = Util.get(fst, term); String termString = term.utf8ToString(); - weights.put(termString, previousValue == null ? 0 : previousValue + modifier * termFreqLocal); + weights.put(termString, previousValue == null ? 0 : Math.max(0, previousValue + modifier * termFreqLocal)); } } if (updateFST) { @@ -216,8 +216,8 @@ public class BooleanPerceptronClassifier implements Classifier { tokenStream.end(); } - double score = 1 - Math.exp(-1 * Math.abs(threshold - output.doubleValue()) / threshold); - return new ClassificationResult<>(output >= threshold, score); + double score = 1 - Math.exp(-1 * Math.abs(bias - output.doubleValue()) / bias); + return new ClassificationResult<>(output >= bias, score); } /** diff --git a/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java b/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java index 6ea92c03b08..ec059f75ac4 100644 --- a/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java +++ b/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java @@ -34,7 +34,9 @@ public class BooleanPerceptronClassifierTest extends ClassificationTestBase avgClassificationTime); // accuracy check disabled until LUCENE-6853 is fixed -// double accuracy = confusionMatrix.getAccuracy(); -// assertTrue(accuracy > 0d); + double accuracy = confusionMatrix.getAccuracy(); + assertTrue(accuracy > 0d); } finally { leafReader.close(); }