LUCENE-6510: take path boosts into account when polling TopNSearcher queue

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1682290 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Areek Zillur 2015-05-28 17:46:55 +00:00
parent bc636dea84
commit 40f3361338
2 changed files with 49 additions and 49 deletions

View File

@ -273,7 +273,7 @@ public final class Util {
@Override @Override
public String toString() { public String toString() {
return "input=" + input + " cost=" + cost + "context=" + context + "boost=" + boost; return "input=" + input.get() + " cost=" + cost + "context=" + context + "boost=" + boost;
} }
} }
@ -307,7 +307,8 @@ public final class Util {
private final FST.Arc<T> scratchArc = new FST.Arc<>(); private final FST.Arc<T> scratchArc = new FST.Arc<>();
final Comparator<T> comparator; private final Comparator<T> comparator;
private final Comparator<FSTPath<T>> pathComparator;
TreeSet<FSTPath<T>> queue = null; TreeSet<FSTPath<T>> queue = null;
@ -329,7 +330,7 @@ public final class Util {
this.topN = topN; this.topN = topN;
this.maxQueueDepth = maxQueueDepth; this.maxQueueDepth = maxQueueDepth;
this.comparator = comparator; this.comparator = comparator;
this.pathComparator = pathComparator;
queue = new TreeSet<>(pathComparator); queue = new TreeSet<>(pathComparator);
} }
@ -343,7 +344,7 @@ public final class Util {
if (queue.size() == maxQueueDepth) { if (queue.size() == maxQueueDepth) {
FSTPath<T> bottom = queue.last(); FSTPath<T> bottom = queue.last();
int comp = comparator.compare(cost, bottom.cost); int comp = pathComparator.compare(path, bottom);
if (comp > 0) { if (comp > 0) {
// Doesn't compete // Doesn't compete
return; return;

View File

@ -478,56 +478,55 @@ public class TestContextQuery extends LuceneTestCase {
@Test @Test
public void testRandomContextQueryScoring() throws Exception { public void testRandomContextQueryScoring() throws Exception {
Analyzer analyzer = new MockAnalyzer(random()); Analyzer analyzer = new MockAnalyzer(random());
RandomIndexWriter iw = new RandomIndexWriter(random(), dir, iwcWithSuggestField(analyzer, "suggest_field")); try(RandomIndexWriter iw = new RandomIndexWriter(random(), dir, iwcWithSuggestField(analyzer, "suggest_field"))) {
int numSuggestions = atLeast(20); int numSuggestions = atLeast(20);
int numContexts = atLeast(5); int numContexts = atLeast(5);
Set<Integer> seenWeights = new HashSet<>(); Set<Integer> seenWeights = new HashSet<>();
List<Entry> expectedEntries = new ArrayList<>(); List<Entry> expectedEntries = new ArrayList<>();
List<CharSequence> contexts = new ArrayList<>(); List<CharSequence> contexts = new ArrayList<>();
for (int i = 1; i <= numContexts; i++) { for (int i = 1; i <= numContexts; i++) {
CharSequence context = TestUtil.randomSimpleString(random(), 10) + i; CharSequence context = TestUtil.randomSimpleString(random(), 10) + i;
contexts.add(context); contexts.add(context);
for (int j = 1; j <= numSuggestions; j++) { for (int j = 1; j <= numSuggestions; j++) {
String suggestion = "sugg_" + TestUtil.randomSimpleString(random(), 10) + j; String suggestion = "sugg_" + TestUtil.randomSimpleString(random(), 10) + j;
int weight = TestUtil.nextInt(random(), 1, 1000 * numContexts * numSuggestions); int weight = TestUtil.nextInt(random(), 1, 1000 * numContexts * numSuggestions);
while (seenWeights.contains(weight)) { while (seenWeights.contains(weight)) {
weight = TestUtil.nextInt(random(), 1, 1000 * numContexts * numSuggestions); weight = TestUtil.nextInt(random(), 1, 1000 * numContexts * numSuggestions);
}
seenWeights.add(weight);
Document document = new Document();
document.add(new ContextSuggestField("suggest_field", Collections.singletonList(context), suggestion, weight));
iw.addDocument(document);
expectedEntries.add(new Entry(suggestion, context.toString(), i * weight));
} }
seenWeights.add(weight); if (rarely()) {
Document document = new Document(); iw.commit();
document.add(new ContextSuggestField("suggest_field", Collections.singletonList(context), suggestion, weight));
iw.addDocument(document);
expectedEntries.add(new Entry(suggestion, context.toString(), i * weight));
}
if (rarely()) {
iw.commit();
}
}
Entry[] expectedResults = expectedEntries.toArray(new Entry[expectedEntries.size()]);
ArrayUtil.introSort(expectedResults, new Comparator<Entry>() {
@Override
public int compare(Entry o1, Entry o2) {
int cmp = Float.compare(o2.value, o1.value);
if (cmp != 0) {
return cmp;
} else {
return o1.output.compareTo(o2.output);
} }
} }
}); Entry[] expectedResults = expectedEntries.toArray(new Entry[expectedEntries.size()]);
DirectoryReader reader = iw.getReader(); ArrayUtil.introSort(expectedResults, new Comparator<Entry>() {
SuggestIndexSearcher suggestIndexSearcher = new SuggestIndexSearcher(reader); @Override
ContextQuery query = new ContextQuery(new PrefixCompletionQuery(analyzer, new Term("suggest_field", "sugg"))); public int compare(Entry o1, Entry o2) {
for (int i = 0; i < contexts.size(); i++) { int cmp = Float.compare(o2.value, o1.value);
query.addContext(contexts.get(i), i + 1); if (cmp != 0) {
return cmp;
} else {
return o1.output.compareTo(o2.output);
}
}
});
try(DirectoryReader reader = iw.getReader()) {
SuggestIndexSearcher suggestIndexSearcher = new SuggestIndexSearcher(reader);
ContextQuery query = new ContextQuery(new PrefixCompletionQuery(analyzer, new Term("suggest_field", "sugg")));
for (int i = 0; i < contexts.size(); i++) {
query.addContext(contexts.get(i), i + 1);
}
TopSuggestDocs suggest = suggestIndexSearcher.suggest(query, 4);
assertSuggestions(suggest, Arrays.copyOfRange(expectedResults, 0, 4));
}
} }
TopSuggestDocs suggest = suggestIndexSearcher.suggest(query, 4);
assertSuggestions(suggest, Arrays.copyOfRange(expectedResults, 0, 4));
reader.close();
iw.close();
} }
} }