mirror of https://github.com/apache/lucene.git
LUCENE-10172 - minor java code improvements to Lucene Classification (#381)
* LUCENE-10172 - minor code improvements * LUCENE-10172 - spotlessApply
This commit is contained in:
parent
c36ce300ae
commit
cfd9f9f98f
|
@ -216,7 +216,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
|||
|
||||
@Override
|
||||
public ClassificationResult<Boolean> assignClass(String text) throws IOException {
|
||||
Long output = 0L;
|
||||
long output = 0L;
|
||||
try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, text)) {
|
||||
CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class);
|
||||
tokenStream.reset();
|
||||
|
@ -230,17 +230,17 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
|||
tokenStream.end();
|
||||
}
|
||||
|
||||
double score = 1 - Math.exp(-1 * Math.abs(bias - output.doubleValue()) / bias);
|
||||
double score = 1 - Math.exp(-1 * Math.abs(bias - (double) output) / bias);
|
||||
return new ClassificationResult<>(output >= bias, score);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ClassificationResult<Boolean>> getClasses(String text) throws IOException {
|
||||
public List<ClassificationResult<Boolean>> getClasses(String text) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ClassificationResult<Boolean>> getClasses(String text, int max) throws IOException {
|
||||
public List<ClassificationResult<Boolean>> getClasses(String text, int max) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@ import java.util.Collections;
|
|||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import org.apache.lucene.analysis.Analyzer;
|
||||
import org.apache.lucene.classification.utils.NearestFuzzyQuery;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
|
@ -93,11 +94,8 @@ public class KNearestFuzzyClassifier implements Classifier<BytesRef> {
|
|||
this.classFieldName = classFieldName;
|
||||
this.analyzer = analyzer;
|
||||
this.indexSearcher = new IndexSearcher(indexReader);
|
||||
if (similarity != null) {
|
||||
this.indexSearcher.setSimilarity(similarity);
|
||||
} else {
|
||||
this.indexSearcher.setSimilarity(new BM25Similarity());
|
||||
}
|
||||
this.indexSearcher.setSimilarity(
|
||||
Objects.requireNonNullElseGet(similarity, BM25Similarity::new));
|
||||
this.query = query;
|
||||
this.k = k;
|
||||
}
|
||||
|
@ -166,7 +164,7 @@ public class KNearestFuzzyClassifier implements Classifier<BytesRef> {
|
|||
if (storableField != null) {
|
||||
BytesRef cl = new BytesRef(storableField.stringValue());
|
||||
// update count
|
||||
classCounts.merge(cl, 1, (a, b) -> a + b);
|
||||
classCounts.merge(cl, 1, Integer::sum);
|
||||
// update boost, the boost is based on the best score
|
||||
Double totalBoost = classBoosts.get(cl);
|
||||
double singleBoost = scoreDoc.score / maxScore;
|
||||
|
|
|
@ -24,6 +24,7 @@ import java.util.Collections;
|
|||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import org.apache.lucene.analysis.Analyzer;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.IndexableField;
|
||||
|
@ -102,11 +103,8 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
|
|||
this.mlt.setAnalyzer(analyzer);
|
||||
this.mlt.setFieldNames(textFieldNames);
|
||||
this.indexSearcher = new IndexSearcher(indexReader);
|
||||
if (similarity != null) {
|
||||
this.indexSearcher.setSimilarity(similarity);
|
||||
} else {
|
||||
this.indexSearcher.setSimilarity(new BM25Similarity());
|
||||
}
|
||||
this.indexSearcher.setSimilarity(
|
||||
Objects.requireNonNullElseGet(similarity, BM25Similarity::new));
|
||||
if (minDocsFreq > 0) {
|
||||
mlt.setMinDocFreq(minDocsFreq);
|
||||
}
|
||||
|
@ -199,7 +197,7 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
|
|||
if (singleStorableField != null) {
|
||||
BytesRef cl = new BytesRef(singleStorableField.stringValue());
|
||||
// update count
|
||||
classCounts.merge(cl, 1, (a, b) -> a + b);
|
||||
classCounts.merge(cl, 1, Integer::sum);
|
||||
// update boost, the boost is based on the best score
|
||||
Double totalBoost = classBoosts.get(cl);
|
||||
double singleBoost = scoreDoc.score / maxScore;
|
||||
|
|
|
@ -282,7 +282,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
|||
}
|
||||
|
||||
private double calculateLogPrior(Term term, int docsWithClassSize) throws IOException {
|
||||
return Math.log((double) docCount(term)) - Math.log(docsWithClassSize);
|
||||
return Math.log(docCount(term)) - Math.log(docsWithClassSize);
|
||||
}
|
||||
|
||||
private int docCount(Term term) throws IOException {
|
||||
|
|
|
@ -269,7 +269,7 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi
|
|||
}
|
||||
|
||||
private double calculateLogPrior(Term term, int docsWithClassSize) throws IOException {
|
||||
return Math.log((double) docCount(term)) - Math.log(docsWithClassSize);
|
||||
return Math.log(docCount(term)) - Math.log(docsWithClassSize);
|
||||
}
|
||||
|
||||
private int docCount(Term term) throws IOException {
|
||||
|
|
Loading…
Reference in New Issue