mirror of https://github.com/apache/lucene.git
LUCENE-6853 - renamed threshold to bias, initialize to avg tf
This commit is contained in:
parent
c05ab96dc4
commit
cbad533d7a
|
@ -58,7 +58,7 @@ import org.apache.lucene.util.fst.Util;
|
||||||
*/
|
*/
|
||||||
public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
||||||
|
|
||||||
private final Double threshold;
|
private final Double bias;
|
||||||
private final Terms textTerms;
|
private final Terms textTerms;
|
||||||
private final Analyzer analyzer;
|
private final Analyzer analyzer;
|
||||||
private final String textFieldName;
|
private final String textFieldName;
|
||||||
|
@ -72,14 +72,14 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
||||||
* @param query a {@link Query} to eventually filter the docs used for training the classifier, or {@code null}
|
* @param query a {@link Query} to eventually filter the docs used for training the classifier, or {@code null}
|
||||||
* if all the indexed docs should be used
|
* if all the indexed docs should be used
|
||||||
* @param batchSize the size of the batch of docs to use for updating the perceptron weights
|
* @param batchSize the size of the batch of docs to use for updating the perceptron weights
|
||||||
* @param threshold the threshold used for class separation
|
* @param bias the bias used for class separation
|
||||||
* @param classFieldName the name of the field used as the output for the classifier
|
* @param classFieldName the name of the field used as the output for the classifier
|
||||||
* @param textFieldName the name of the field used as input for the classifier
|
* @param textFieldName the name of the field used as input for the classifier
|
||||||
* @throws IOException if the building of the underlying {@link FST} fails and / or {@link TermsEnum} for the text field
|
* @throws IOException if the building of the underlying {@link FST} fails and / or {@link TermsEnum} for the text field
|
||||||
* cannot be found
|
* cannot be found
|
||||||
*/
|
*/
|
||||||
public BooleanPerceptronClassifier(IndexReader indexReader, Analyzer analyzer, Query query, Integer batchSize,
|
public BooleanPerceptronClassifier(IndexReader indexReader, Analyzer analyzer, Query query, Integer batchSize,
|
||||||
Double threshold, String classFieldName, String textFieldName) throws IOException {
|
Double bias, String classFieldName, String textFieldName) throws IOException {
|
||||||
this.textTerms = MultiFields.getTerms(indexReader, textFieldName);
|
this.textTerms = MultiFields.getTerms(indexReader, textFieldName);
|
||||||
|
|
||||||
if (textTerms == null) {
|
if (textTerms == null) {
|
||||||
|
@ -89,18 +89,18 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
||||||
this.analyzer = analyzer;
|
this.analyzer = analyzer;
|
||||||
this.textFieldName = textFieldName;
|
this.textFieldName = textFieldName;
|
||||||
|
|
||||||
if (threshold == null || threshold == 0d) {
|
if (bias == null || bias == 0d) {
|
||||||
// automatic assign a threshold
|
// automatic assign the bias to be the average total term freq
|
||||||
long sumDocFreq = indexReader.getSumDocFreq(textFieldName);
|
double t = (double) indexReader.getSumTotalTermFreq(textFieldName) / (double) indexReader.getDocCount(textFieldName);
|
||||||
if (sumDocFreq != -1) {
|
if (t != -1) {
|
||||||
this.threshold = (double) sumDocFreq / 2d;
|
this.bias = t;
|
||||||
} else {
|
} else {
|
||||||
throw new IOException(
|
throw new IOException(
|
||||||
"threshold cannot be assigned since term vectors for field "
|
"bias cannot be assigned since term vectors for field "
|
||||||
+ textFieldName + " do not exist");
|
+ textFieldName + " do not exist");
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
this.threshold = threshold;
|
this.bias = bias;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO : remove this map as soon as we have a writable FST
|
// TODO : remove this map as soon as we have a writable FST
|
||||||
|
@ -173,7 +173,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
||||||
// update weights
|
// update weights
|
||||||
Long previousValue = Util.get(fst, term);
|
Long previousValue = Util.get(fst, term);
|
||||||
String termString = term.utf8ToString();
|
String termString = term.utf8ToString();
|
||||||
weights.put(termString, previousValue == null ? 0 : previousValue + modifier * termFreqLocal);
|
weights.put(termString, previousValue == null ? 0 : Math.max(0, previousValue + modifier * termFreqLocal));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (updateFST) {
|
if (updateFST) {
|
||||||
|
@ -216,8 +216,8 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
||||||
tokenStream.end();
|
tokenStream.end();
|
||||||
}
|
}
|
||||||
|
|
||||||
double score = 1 - Math.exp(-1 * Math.abs(threshold - output.doubleValue()) / threshold);
|
double score = 1 - Math.exp(-1 * Math.abs(bias - output.doubleValue()) / bias);
|
||||||
return new ClassificationResult<>(output >= threshold, score);
|
return new ClassificationResult<>(output >= bias, score);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -34,7 +34,9 @@ public class BooleanPerceptronClassifierTest extends ClassificationTestBase<Bool
|
||||||
try {
|
try {
|
||||||
MockAnalyzer analyzer = new MockAnalyzer(random());
|
MockAnalyzer analyzer = new MockAnalyzer(random());
|
||||||
leafReader = getSampleIndex(analyzer);
|
leafReader = getSampleIndex(analyzer);
|
||||||
checkCorrectClassification(new BooleanPerceptronClassifier(leafReader, analyzer, null, 1, null, booleanFieldName, textFieldName), TECHNOLOGY_INPUT, false);
|
BooleanPerceptronClassifier classifier = new BooleanPerceptronClassifier(leafReader, analyzer, null, 1, null, booleanFieldName, textFieldName);
|
||||||
|
checkCorrectClassification(classifier, TECHNOLOGY_INPUT, false);
|
||||||
|
checkCorrectClassification(classifier, POLITICS_INPUT, true);
|
||||||
} finally {
|
} finally {
|
||||||
if (leafReader != null) {
|
if (leafReader != null) {
|
||||||
leafReader.close();
|
leafReader.close();
|
||||||
|
@ -60,12 +62,14 @@ public class BooleanPerceptronClassifierTest extends ClassificationTestBase<Bool
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBasicUsageWithQuery() throws Exception {
|
public void testBasicUsageWithQuery() throws Exception {
|
||||||
TermQuery query = new TermQuery(new Term(textFieldName, "it"));
|
TermQuery query = new TermQuery(new Term(textFieldName, "of"));
|
||||||
LeafReader leafReader = null;
|
LeafReader leafReader = null;
|
||||||
try {
|
try {
|
||||||
MockAnalyzer analyzer = new MockAnalyzer(random());
|
MockAnalyzer analyzer = new MockAnalyzer(random());
|
||||||
leafReader = getSampleIndex(analyzer);
|
leafReader = getSampleIndex(analyzer);
|
||||||
checkCorrectClassification(new BooleanPerceptronClassifier(leafReader, analyzer, query, 1, null, booleanFieldName, textFieldName), TECHNOLOGY_INPUT, false);
|
BooleanPerceptronClassifier classifier = new BooleanPerceptronClassifier(leafReader, analyzer, query, 1, null, booleanFieldName, textFieldName);
|
||||||
|
checkCorrectClassification(classifier, TECHNOLOGY_INPUT, false);
|
||||||
|
checkCorrectClassification(classifier, POLITICS_INPUT, true);
|
||||||
} finally {
|
} finally {
|
||||||
if (leafReader != null) {
|
if (leafReader != null) {
|
||||||
leafReader.close();
|
leafReader.close();
|
||||||
|
@ -94,8 +98,8 @@ public class BooleanPerceptronClassifierTest extends ClassificationTestBase<Bool
|
||||||
double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
|
double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
|
||||||
assertTrue(5000 > avgClassificationTime);
|
assertTrue(5000 > avgClassificationTime);
|
||||||
// accuracy check disabled until LUCENE-6853 is fixed
|
// accuracy check disabled until LUCENE-6853 is fixed
|
||||||
// double accuracy = confusionMatrix.getAccuracy();
|
double accuracy = confusionMatrix.getAccuracy();
|
||||||
// assertTrue(accuracy > 0d);
|
assertTrue(accuracy > 0d);
|
||||||
} finally {
|
} finally {
|
||||||
leafReader.close();
|
leafReader.close();
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue