mirror of https://github.com/apache/lucene.git
Fix MLT like text with custom frequencies
When an analyzer with custom term frequencies is used with MLT like texts, the custom term frequencies are incorrectly ignored and a fixed frequency of 1 is used instead. This commit fixes the issue by reading the term frequency from `TermFrequencyAttribute` instead of using a fixed value of 1. It also adds a test case for the mentioned issue.
This commit is contained in:
parent
5ca0602d28
commit
351e21f620
|
@ -28,6 +28,7 @@ import java.util.Set;
|
||||||
import org.apache.lucene.analysis.Analyzer;
|
import org.apache.lucene.analysis.Analyzer;
|
||||||
import org.apache.lucene.analysis.TokenStream;
|
import org.apache.lucene.analysis.TokenStream;
|
||||||
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
|
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
|
||||||
|
import org.apache.lucene.analysis.tokenattributes.TermFrequencyAttribute;
|
||||||
import org.apache.lucene.document.Document;
|
import org.apache.lucene.document.Document;
|
||||||
import org.apache.lucene.index.FieldInfos;
|
import org.apache.lucene.index.FieldInfos;
|
||||||
import org.apache.lucene.index.Fields;
|
import org.apache.lucene.index.Fields;
|
||||||
|
@ -824,6 +825,7 @@ public final class MoreLikeThis {
|
||||||
int tokenCount = 0;
|
int tokenCount = 0;
|
||||||
// for every token
|
// for every token
|
||||||
CharTermAttribute termAtt = ts.addAttribute(CharTermAttribute.class);
|
CharTermAttribute termAtt = ts.addAttribute(CharTermAttribute.class);
|
||||||
|
TermFrequencyAttribute tfAtt = ts.addAttribute(TermFrequencyAttribute.class);
|
||||||
ts.reset();
|
ts.reset();
|
||||||
while (ts.incrementToken()) {
|
while (ts.incrementToken()) {
|
||||||
String word = termAtt.toString();
|
String word = termAtt.toString();
|
||||||
|
@ -838,9 +840,9 @@ public final class MoreLikeThis {
|
||||||
// increment frequency
|
// increment frequency
|
||||||
Int cnt = termFreqMap.get(word);
|
Int cnt = termFreqMap.get(word);
|
||||||
if (cnt == null) {
|
if (cnt == null) {
|
||||||
termFreqMap.put(word, new Int());
|
termFreqMap.put(word, new Int(tfAtt.getTermFrequency()));
|
||||||
} else {
|
} else {
|
||||||
cnt.x++;
|
cnt.x += tfAtt.getTermFrequency();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ts.end();
|
ts.end();
|
||||||
|
@ -982,7 +984,11 @@ public final class MoreLikeThis {
|
||||||
int x;
|
int x;
|
||||||
|
|
||||||
Int() {
|
Int() {
|
||||||
x = 1;
|
this(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
Int(int initialValue) {
|
||||||
|
x = initialValue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,7 +27,12 @@ import java.util.Map;
|
||||||
|
|
||||||
import org.apache.lucene.analysis.Analyzer;
|
import org.apache.lucene.analysis.Analyzer;
|
||||||
import org.apache.lucene.analysis.MockAnalyzer;
|
import org.apache.lucene.analysis.MockAnalyzer;
|
||||||
|
import org.apache.lucene.analysis.MockTokenFilter;
|
||||||
import org.apache.lucene.analysis.MockTokenizer;
|
import org.apache.lucene.analysis.MockTokenizer;
|
||||||
|
import org.apache.lucene.analysis.TokenFilter;
|
||||||
|
import org.apache.lucene.analysis.TokenStream;
|
||||||
|
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
|
||||||
|
import org.apache.lucene.analysis.tokenattributes.TermFrequencyAttribute;
|
||||||
import org.apache.lucene.document.Document;
|
import org.apache.lucene.document.Document;
|
||||||
import org.apache.lucene.document.Field;
|
import org.apache.lucene.document.Field;
|
||||||
import org.apache.lucene.index.IndexReader;
|
import org.apache.lucene.index.IndexReader;
|
||||||
|
@ -41,6 +46,7 @@ import org.apache.lucene.search.Query;
|
||||||
import org.apache.lucene.search.QueryUtils;
|
import org.apache.lucene.search.QueryUtils;
|
||||||
import org.apache.lucene.search.TermQuery;
|
import org.apache.lucene.search.TermQuery;
|
||||||
import org.apache.lucene.store.Directory;
|
import org.apache.lucene.store.Directory;
|
||||||
|
import org.apache.lucene.util.ArrayUtil;
|
||||||
import org.apache.lucene.util.LuceneTestCase;
|
import org.apache.lucene.util.LuceneTestCase;
|
||||||
|
|
||||||
import static org.hamcrest.core.Is.is;
|
import static org.hamcrest.core.Is.is;
|
||||||
|
@ -427,5 +433,69 @@ public class TestMoreLikeThis extends LuceneTestCase {
|
||||||
analyzer.close();
|
analyzer.close();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testCustomFrequecy() throws IOException {
|
||||||
|
// define an analyzer with delimited term frequency, e.g. "foo|2 bar|3"
|
||||||
|
Analyzer analyzer = new Analyzer() {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected TokenStreamComponents createComponents(String fieldName) {
|
||||||
|
MockTokenizer tokenizer = new MockTokenizer(MockTokenizer.WHITESPACE, false, 100);
|
||||||
|
MockTokenFilter filt = new MockTokenFilter(tokenizer, MockTokenFilter.EMPTY_STOPSET);
|
||||||
|
return new TokenStreamComponents(tokenizer, addCustomTokenFilter(filt));
|
||||||
|
}
|
||||||
|
|
||||||
|
TokenStream addCustomTokenFilter(TokenStream input) {
|
||||||
|
return new TokenFilter(input) {
|
||||||
|
final CharTermAttribute termAtt = addAttribute(CharTermAttribute.class);
|
||||||
|
final TermFrequencyAttribute tfAtt = addAttribute(TermFrequencyAttribute.class);
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean incrementToken() throws IOException {
|
||||||
|
if (input.incrementToken()) {
|
||||||
|
final char[] buffer = termAtt.buffer();
|
||||||
|
final int length = termAtt.length();
|
||||||
|
for (int i = 0; i < length; i++) {
|
||||||
|
if (buffer[i] == '|') {
|
||||||
|
termAtt.setLength(i);
|
||||||
|
i++;
|
||||||
|
tfAtt.setTermFrequency(ArrayUtil.parseInt(buffer, i, length - i));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
mlt.setAnalyzer(analyzer);
|
||||||
|
mlt.setFieldNames(new String[] {"text"});
|
||||||
|
mlt.setBoost(true);
|
||||||
|
|
||||||
|
final double boost10 = ((BooleanQuery) mlt.like("text", new StringReader("lucene|10 release|1")))
|
||||||
|
.clauses()
|
||||||
|
.stream()
|
||||||
|
.map(BooleanClause::getQuery)
|
||||||
|
.map(BoostQuery.class::cast)
|
||||||
|
.filter(x -> ((TermQuery) x.getQuery()).getTerm().text().equals("lucene"))
|
||||||
|
.mapToDouble(BoostQuery::getBoost)
|
||||||
|
.sum();
|
||||||
|
|
||||||
|
final double boost1 = ((BooleanQuery) mlt.like("text", new StringReader("lucene|1 release|1")))
|
||||||
|
.clauses()
|
||||||
|
.stream()
|
||||||
|
.map(BooleanClause::getQuery)
|
||||||
|
.map(BoostQuery.class::cast)
|
||||||
|
.filter(x -> ((TermQuery) x.getQuery()).getTerm().text().equals("lucene"))
|
||||||
|
.mapToDouble(BoostQuery::getBoost)
|
||||||
|
.sum();
|
||||||
|
|
||||||
|
// mlt should use the custom frequencies provided by the analyzer so "lucene|10" should be boosted more than "lucene|1"
|
||||||
|
assertTrue(String.format("%s should be grater than %s", boost10, boost1), boost10 > boost1);
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: add tests for the MoreLikeThisQuery
|
// TODO: add tests for the MoreLikeThisQuery
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue