Fix MLT like text with custom frequencies

When an analyzer that produces custom term frequencies is used with MLT
like texts, those frequencies are ignored and a fixed frequency of 1 is
used instead.

This commit fixes the issue by using `TermFrequencyAttribute` to get
the term frequencies instead of a fixed 1. It also adds test cases for
the mentioned issue.
Olli Kuonanoja 2019-04-08 16:44:30 +03:00 committed by Mike McCandless
parent 5ca0602d28
commit 351e21f620
2 changed files with 79 additions and 3 deletions
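
For context, a minimal sketch (not part of this commit, illustrative names only) of the mechanism the fix relies on: a TokenStream consumer can read TermFrequencyAttribute next to CharTermAttribute. For analyzers that never set the attribute, getTermFrequency() returns the default of 1, so accumulating it per term matches the old per-token count while honoring analyzers that do supply custom frequencies, as the patched loop below does.

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.TermFrequencyAttribute;

public class TermFrequencySketch {

  // Accumulate per-term frequencies the way the patched loop does:
  // add getTermFrequency() for each token instead of a fixed 1.
  static Map<String, Integer> countTerms(Analyzer analyzer, String field, String text) throws IOException {
    Map<String, Integer> freqs = new HashMap<>();
    try (TokenStream ts = analyzer.tokenStream(field, text)) {
      CharTermAttribute termAtt = ts.addAttribute(CharTermAttribute.class);
      TermFrequencyAttribute tfAtt = ts.addAttribute(TermFrequencyAttribute.class);
      ts.reset();
      while (ts.incrementToken()) {
        // 1 for ordinary analyzers, the custom value when the analyzer sets it
        freqs.merge(termAtt.toString(), tfAtt.getTermFrequency(), Integer::sum);
      }
      ts.end();
    }
    return freqs;
  }
}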

MoreLikeThis.java

@@ -28,6 +28,7 @@ import java.util.Set;
 import org.apache.lucene.analysis.Analyzer;
 import org.apache.lucene.analysis.TokenStream;
 import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
+import org.apache.lucene.analysis.tokenattributes.TermFrequencyAttribute;
 import org.apache.lucene.document.Document;
 import org.apache.lucene.index.FieldInfos;
 import org.apache.lucene.index.Fields;
@@ -824,6 +825,7 @@ public final class MoreLikeThis {
       int tokenCount = 0;
       // for every token
       CharTermAttribute termAtt = ts.addAttribute(CharTermAttribute.class);
+      TermFrequencyAttribute tfAtt = ts.addAttribute(TermFrequencyAttribute.class);
       ts.reset();
       while (ts.incrementToken()) {
         String word = termAtt.toString();
@@ -838,9 +840,9 @@ public final class MoreLikeThis {
         // increment frequency
         Int cnt = termFreqMap.get(word);
         if (cnt == null) {
-          termFreqMap.put(word, new Int());
+          termFreqMap.put(word, new Int(tfAtt.getTermFrequency()));
         } else {
-          cnt.x++;
+          cnt.x += tfAtt.getTermFrequency();
         }
       }
       ts.end();
@@ -982,7 +984,11 @@ public final class MoreLikeThis {
     int x;

     Int() {
-      x = 1;
+      this(1);
+    }
+
+    Int(int initialValue) {
+      x = initialValue;
     }
   }
 }

TestMoreLikeThis.java

@@ -27,7 +27,12 @@ import java.util.Map;
 import org.apache.lucene.analysis.Analyzer;
 import org.apache.lucene.analysis.MockAnalyzer;
+import org.apache.lucene.analysis.MockTokenFilter;
 import org.apache.lucene.analysis.MockTokenizer;
+import org.apache.lucene.analysis.TokenFilter;
+import org.apache.lucene.analysis.TokenStream;
+import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
+import org.apache.lucene.analysis.tokenattributes.TermFrequencyAttribute;
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.Field;
 import org.apache.lucene.index.IndexReader;
@@ -41,6 +46,7 @@ import org.apache.lucene.search.Query;
 import org.apache.lucene.search.QueryUtils;
 import org.apache.lucene.search.TermQuery;
 import org.apache.lucene.store.Directory;
+import org.apache.lucene.util.ArrayUtil;
 import org.apache.lucene.util.LuceneTestCase;

 import static org.hamcrest.core.Is.is;
@@ -427,5 +433,69 @@ public class TestMoreLikeThis extends LuceneTestCase {
       analyzer.close();
     }
   }

+  public void testCustomFrequecy() throws IOException {
+    // define an analyzer with delimited term frequency, e.g. "foo|2 bar|3"
+    Analyzer analyzer = new Analyzer() {
+
+      @Override
+      protected TokenStreamComponents createComponents(String fieldName) {
+        MockTokenizer tokenizer = new MockTokenizer(MockTokenizer.WHITESPACE, false, 100);
+        MockTokenFilter filt = new MockTokenFilter(tokenizer, MockTokenFilter.EMPTY_STOPSET);
+        return new TokenStreamComponents(tokenizer, addCustomTokenFilter(filt));
+      }
+
+      TokenStream addCustomTokenFilter(TokenStream input) {
+        return new TokenFilter(input) {
+          final CharTermAttribute termAtt = addAttribute(CharTermAttribute.class);
+          final TermFrequencyAttribute tfAtt = addAttribute(TermFrequencyAttribute.class);
+
+          @Override
+          public boolean incrementToken() throws IOException {
+            if (input.incrementToken()) {
+              final char[] buffer = termAtt.buffer();
+              final int length = termAtt.length();
+              for (int i = 0; i < length; i++) {
+                if (buffer[i] == '|') {
+                  termAtt.setLength(i);
+                  i++;
+                  tfAtt.setTermFrequency(ArrayUtil.parseInt(buffer, i, length - i));
+                  return true;
+                }
+              }
+              return true;
+            }
+            return false;
+          }
+        };
+      }
+    };
+
+    mlt.setAnalyzer(analyzer);
+    mlt.setFieldNames(new String[] {"text"});
+    mlt.setBoost(true);
+
+    final double boost10 = ((BooleanQuery) mlt.like("text", new StringReader("lucene|10 release|1")))
+        .clauses()
+        .stream()
+        .map(BooleanClause::getQuery)
+        .map(BoostQuery.class::cast)
+        .filter(x -> ((TermQuery) x.getQuery()).getTerm().text().equals("lucene"))
+        .mapToDouble(BoostQuery::getBoost)
+        .sum();
+
+    final double boost1 = ((BooleanQuery) mlt.like("text", new StringReader("lucene|1 release|1")))
+        .clauses()
+        .stream()
+        .map(BooleanClause::getQuery)
+        .map(BoostQuery.class::cast)
+        .filter(x -> ((TermQuery) x.getQuery()).getTerm().text().equals("lucene"))
+        .mapToDouble(BoostQuery::getBoost)
+        .sum();
+
+    // mlt should use the custom frequencies provided by the analyzer, so "lucene|10" should be boosted more than "lucene|1"
+    assertTrue(String.format("%s should be greater than %s", boost10, boost1), boost10 > boost1);
+  }
+
   // TODO: add tests for the MoreLikeThisQuery
 }
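
Outside the test's mock infrastructure, the same "term|freq" convention can be wired into MoreLikeThis from application code. Below is a minimal sketch, assuming the analyzers-common DelimitedTermFrequencyTokenFilter and a hypothetical field name "text"; the min term/doc frequency settings are illustrative only. With setBoost(true), the per-term boosts of the resulting query reflect the custom frequencies, which is what the test above asserts.

import java.io.IOException;
import java.io.StringReader;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.Tokenizer;
import org.apache.lucene.analysis.core.WhitespaceTokenizer;
import org.apache.lucene.analysis.miscellaneous.DelimitedTermFrequencyTokenFilter;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.queries.mlt.MoreLikeThis;
import org.apache.lucene.search.Query;

public class MltCustomFrequencySketch {

  static Query likeWithCustomFrequencies(IndexReader reader) throws IOException {
    // Analyzer that understands "term|freq" input, e.g. "lucene|10 release|1".
    Analyzer analyzer = new Analyzer() {
      @Override
      protected TokenStreamComponents createComponents(String fieldName) {
        Tokenizer source = new WhitespaceTokenizer();
        return new TokenStreamComponents(source, new DelimitedTermFrequencyTokenFilter(source));
      }
    };

    MoreLikeThis mlt = new MoreLikeThis(reader);
    mlt.setAnalyzer(analyzer);
    mlt.setFieldNames(new String[] {"text"}); // hypothetical field name
    mlt.setBoost(true);    // boost each selected term by its (custom) frequency
    mlt.setMinTermFreq(1); // illustrative: keep low-frequency terms for the sketch
    mlt.setMinDocFreq(1);

    // With the fix, "lucene" contributes a frequency of 10 rather than 1,
    // so its clause is boosted accordingly.
    return mlt.like("text", new StringReader("lucene|10 release|1"));
  }
}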