diff --git a/lucene/queries/src/java/org/apache/lucene/queries/mlt/MoreLikeThis.java b/lucene/queries/src/java/org/apache/lucene/queries/mlt/MoreLikeThis.java index 61ebe937ee4..7c077e53fe2 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/mlt/MoreLikeThis.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/mlt/MoreLikeThis.java @@ -28,6 +28,7 @@ import java.util.Set; import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.analysis.TokenStream; import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; +import org.apache.lucene.analysis.tokenattributes.TermFrequencyAttribute; import org.apache.lucene.document.Document; import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.Fields; @@ -824,6 +825,7 @@ public final class MoreLikeThis { int tokenCount = 0; // for every token CharTermAttribute termAtt = ts.addAttribute(CharTermAttribute.class); + TermFrequencyAttribute tfAtt = ts.addAttribute(TermFrequencyAttribute.class); ts.reset(); while (ts.incrementToken()) { String word = termAtt.toString(); @@ -838,9 +840,9 @@ public final class MoreLikeThis { // increment frequency Int cnt = termFreqMap.get(word); if (cnt == null) { - termFreqMap.put(word, new Int()); + termFreqMap.put(word, new Int(tfAtt.getTermFrequency())); } else { - cnt.x++; + cnt.x += tfAtt.getTermFrequency(); } } ts.end(); @@ -982,7 +984,11 @@ public final class MoreLikeThis { int x; Int() { - x = 1; + this(1); + } + + Int(int initialValue) { + x = initialValue; } } } diff --git a/lucene/queries/src/test/org/apache/lucene/queries/mlt/TestMoreLikeThis.java b/lucene/queries/src/test/org/apache/lucene/queries/mlt/TestMoreLikeThis.java index 4a60015c485..aeec5348c27 100644 --- a/lucene/queries/src/test/org/apache/lucene/queries/mlt/TestMoreLikeThis.java +++ b/lucene/queries/src/test/org/apache/lucene/queries/mlt/TestMoreLikeThis.java @@ -27,7 +27,12 @@ import java.util.Map; import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.analysis.MockAnalyzer; +import org.apache.lucene.analysis.MockTokenFilter; import org.apache.lucene.analysis.MockTokenizer; +import org.apache.lucene.analysis.TokenFilter; +import org.apache.lucene.analysis.TokenStream; +import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; +import org.apache.lucene.analysis.tokenattributes.TermFrequencyAttribute; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.index.IndexReader; @@ -41,6 +46,7 @@ import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryUtils; import org.apache.lucene.search.TermQuery; import org.apache.lucene.store.Directory; +import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.LuceneTestCase; import static org.hamcrest.core.Is.is; @@ -427,5 +433,69 @@ public class TestMoreLikeThis extends LuceneTestCase { analyzer.close(); } } + + public void testCustomFrequecy() throws IOException { + // define an analyzer with delimited term frequency, e.g. "foo|2 bar|3" + Analyzer analyzer = new Analyzer() { + + @Override + protected TokenStreamComponents createComponents(String fieldName) { + MockTokenizer tokenizer = new MockTokenizer(MockTokenizer.WHITESPACE, false, 100); + MockTokenFilter filt = new MockTokenFilter(tokenizer, MockTokenFilter.EMPTY_STOPSET); + return new TokenStreamComponents(tokenizer, addCustomTokenFilter(filt)); + } + + TokenStream addCustomTokenFilter(TokenStream input) { + return new TokenFilter(input) { + final CharTermAttribute termAtt = addAttribute(CharTermAttribute.class); + final TermFrequencyAttribute tfAtt = addAttribute(TermFrequencyAttribute.class); + + @Override + public boolean incrementToken() throws IOException { + if (input.incrementToken()) { + final char[] buffer = termAtt.buffer(); + final int length = termAtt.length(); + for (int i = 0; i < length; i++) { + if (buffer[i] == '|') { + termAtt.setLength(i); + i++; + tfAtt.setTermFrequency(ArrayUtil.parseInt(buffer, i, length - i)); + return true; + } + } + return true; + } + return false; + } + }; + } + }; + + mlt.setAnalyzer(analyzer); + mlt.setFieldNames(new String[] {"text"}); + mlt.setBoost(true); + + final double boost10 = ((BooleanQuery) mlt.like("text", new StringReader("lucene|10 release|1"))) + .clauses() + .stream() + .map(BooleanClause::getQuery) + .map(BoostQuery.class::cast) + .filter(x -> ((TermQuery) x.getQuery()).getTerm().text().equals("lucene")) + .mapToDouble(BoostQuery::getBoost) + .sum(); + + final double boost1 = ((BooleanQuery) mlt.like("text", new StringReader("lucene|1 release|1"))) + .clauses() + .stream() + .map(BooleanClause::getQuery) + .map(BoostQuery.class::cast) + .filter(x -> ((TermQuery) x.getQuery()).getTerm().text().equals("lucene")) + .mapToDouble(BoostQuery::getBoost) + .sum(); + + // mlt should use the custom frequencies provided by the analyzer so "lucene|10" should be boosted more than "lucene|1" + assertTrue(String.format("%s should be grater than %s", boost10, boost1), boost10 > boost1); + } + // TODO: add tests for the MoreLikeThisQuery }