mirror of https://github.com/apache/lucene.git
Fix MLT like text with custom frequencies
When an analyzer with custom term frequencies is used with MLT like texts, the custom term frequencies are incorrectly omitted and a fixed frequency of 1 is used instead. This commit fixes the issue by using `TermFrequencyAttribute` to get the term frequencies instead of using fixed 1. Also adds test cases for them mentioned issue.
This commit is contained in:
parent
412bb8077c
commit
2f1f9f8a36
|
@ -28,6 +28,7 @@ import java.util.Set;
|
|||
import org.apache.lucene.analysis.Analyzer;
|
||||
import org.apache.lucene.analysis.TokenStream;
|
||||
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
|
||||
import org.apache.lucene.analysis.tokenattributes.TermFrequencyAttribute;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.index.FieldInfos;
|
||||
import org.apache.lucene.index.Fields;
|
||||
|
@ -824,6 +825,7 @@ public final class MoreLikeThis {
|
|||
int tokenCount = 0;
|
||||
// for every token
|
||||
CharTermAttribute termAtt = ts.addAttribute(CharTermAttribute.class);
|
||||
TermFrequencyAttribute tfAtt = ts.addAttribute(TermFrequencyAttribute.class);
|
||||
ts.reset();
|
||||
while (ts.incrementToken()) {
|
||||
String word = termAtt.toString();
|
||||
|
@ -838,9 +840,9 @@ public final class MoreLikeThis {
|
|||
// increment frequency
|
||||
Int cnt = termFreqMap.get(word);
|
||||
if (cnt == null) {
|
||||
termFreqMap.put(word, new Int());
|
||||
termFreqMap.put(word, new Int(tfAtt.getTermFrequency()));
|
||||
} else {
|
||||
cnt.x++;
|
||||
cnt.x += tfAtt.getTermFrequency();
|
||||
}
|
||||
}
|
||||
ts.end();
|
||||
|
@ -982,7 +984,11 @@ public final class MoreLikeThis {
|
|||
int x;
|
||||
|
||||
Int() {
|
||||
x = 1;
|
||||
this(1);
|
||||
}
|
||||
|
||||
Int(int initialValue) {
|
||||
x = initialValue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -27,7 +27,12 @@ import java.util.Map;
|
|||
|
||||
import org.apache.lucene.analysis.Analyzer;
|
||||
import org.apache.lucene.analysis.MockAnalyzer;
|
||||
import org.apache.lucene.analysis.MockTokenFilter;
|
||||
import org.apache.lucene.analysis.MockTokenizer;
|
||||
import org.apache.lucene.analysis.TokenFilter;
|
||||
import org.apache.lucene.analysis.TokenStream;
|
||||
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
|
||||
import org.apache.lucene.analysis.tokenattributes.TermFrequencyAttribute;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
|
@ -41,6 +46,7 @@ import org.apache.lucene.search.Query;
|
|||
import org.apache.lucene.search.QueryUtils;
|
||||
import org.apache.lucene.search.TermQuery;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
import org.apache.lucene.util.LuceneTestCase;
|
||||
|
||||
import static org.hamcrest.core.Is.is;
|
||||
|
@ -427,5 +433,69 @@ public class TestMoreLikeThis extends LuceneTestCase {
|
|||
analyzer.close();
|
||||
}
|
||||
}
|
||||
|
||||
public void testCustomFrequecy() throws IOException {
|
||||
// define an analyzer with delimited term frequency, e.g. "foo|2 bar|3"
|
||||
Analyzer analyzer = new Analyzer() {
|
||||
|
||||
@Override
|
||||
protected TokenStreamComponents createComponents(String fieldName) {
|
||||
MockTokenizer tokenizer = new MockTokenizer(MockTokenizer.WHITESPACE, false, 100);
|
||||
MockTokenFilter filt = new MockTokenFilter(tokenizer, MockTokenFilter.EMPTY_STOPSET);
|
||||
return new TokenStreamComponents(tokenizer, addCustomTokenFilter(filt));
|
||||
}
|
||||
|
||||
TokenStream addCustomTokenFilter(TokenStream input) {
|
||||
return new TokenFilter(input) {
|
||||
final CharTermAttribute termAtt = addAttribute(CharTermAttribute.class);
|
||||
final TermFrequencyAttribute tfAtt = addAttribute(TermFrequencyAttribute.class);
|
||||
|
||||
@Override
|
||||
public boolean incrementToken() throws IOException {
|
||||
if (input.incrementToken()) {
|
||||
final char[] buffer = termAtt.buffer();
|
||||
final int length = termAtt.length();
|
||||
for (int i = 0; i < length; i++) {
|
||||
if (buffer[i] == '|') {
|
||||
termAtt.setLength(i);
|
||||
i++;
|
||||
tfAtt.setTermFrequency(ArrayUtil.parseInt(buffer, i, length - i));
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
mlt.setAnalyzer(analyzer);
|
||||
mlt.setFieldNames(new String[] {"text"});
|
||||
mlt.setBoost(true);
|
||||
|
||||
final double boost10 = ((BooleanQuery) mlt.like("text", new StringReader("lucene|10 release|1")))
|
||||
.clauses()
|
||||
.stream()
|
||||
.map(BooleanClause::getQuery)
|
||||
.map(BoostQuery.class::cast)
|
||||
.filter(x -> ((TermQuery) x.getQuery()).getTerm().text().equals("lucene"))
|
||||
.mapToDouble(BoostQuery::getBoost)
|
||||
.sum();
|
||||
|
||||
final double boost1 = ((BooleanQuery) mlt.like("text", new StringReader("lucene|1 release|1")))
|
||||
.clauses()
|
||||
.stream()
|
||||
.map(BooleanClause::getQuery)
|
||||
.map(BoostQuery.class::cast)
|
||||
.filter(x -> ((TermQuery) x.getQuery()).getTerm().text().equals("lucene"))
|
||||
.mapToDouble(BoostQuery::getBoost)
|
||||
.sum();
|
||||
|
||||
// mlt should use the custom frequencies provided by the analyzer so "lucene|10" should be boosted more than "lucene|1"
|
||||
assertTrue(String.format("%s should be grater than %s", boost10, boost1), boost10 > boost1);
|
||||
}
|
||||
|
||||
// TODO: add tests for the MoreLikeThisQuery
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue