Fix duplicate removal when merging completion suggestions (#36996)

The completion suggester ignores the original weight of the suggestion when duplicates are removed. This change fixes this bug and keeps the best weighted suggestion among the duplicates. It also removes the custom implementation of the top docs suggest collector now that https://issues.apache.org/jira/browse/LUCENE-8529 is committed in Lucene.

Closes #35836
This commit is contained in:
Jim Ferenczi 2019-01-15 19:27:31 +01:00 committed by GitHub
parent bea46f7b52
commit f8d80dff7c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 117 additions and 221 deletions

View File

@ -20,7 +20,7 @@ package org.elasticsearch.search.suggest.completion;
import org.apache.lucene.analysis.CharArraySet; import org.apache.lucene.analysis.CharArraySet;
import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.suggest.Lookup; import org.apache.lucene.util.PriorityQueue;
import org.elasticsearch.Version; import org.elasticsearch.Version;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamInput;
@ -36,9 +36,9 @@ import org.elasticsearch.search.suggest.Suggest.Suggestion;
import java.io.IOException; import java.io.IOException;
import java.util.Collections; import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -138,31 +138,34 @@ public final class CompletionSuggestion extends Suggest.Suggestion<CompletionSug
return suggestion; return suggestion;
} }
private static final class OptionPriorityQueue extends org.apache.lucene.util.PriorityQueue<Entry.Option> { private static final class OptionPriorityQueue extends PriorityQueue<ShardOptions> {
OptionPriorityQueue(int maxSize) {
private final Comparator<Suggest.Suggestion.Entry.Option> comparator;
OptionPriorityQueue(int maxSize, Comparator<Suggest.Suggestion.Entry.Option> comparator) {
super(maxSize); super(maxSize);
this.comparator = comparator;
} }
@Override @Override
protected boolean lessThan(Entry.Option a, Entry.Option b) { protected boolean lessThan(ShardOptions a, ShardOptions b) {
int cmp = comparator.compare(a, b); return COMPARATOR.compare(a.current, b.current) < 0;
if (cmp != 0) {
return cmp > 0;
} }
return Lookup.CHARSEQUENCE_COMPARATOR.compare(a.getText().string(), b.getText().string()) > 0;
} }
Entry.Option[] get() { private static class ShardOptions {
int size = size(); final Iterator<Entry.Option> optionsIterator;
Entry.Option[] results = new Entry.Option[size]; Entry.Option current;
for (int i = size - 1; i >= 0; i--) {
results[i] = pop(); private ShardOptions(Iterator<Entry.Option> optionsIterator) {
assert optionsIterator.hasNext();
this.optionsIterator = optionsIterator;
this.current = optionsIterator.next();
}
boolean advanceToNextOption() {
if (optionsIterator.hasNext()) {
current = optionsIterator.next();
return true;
} else {
return false;
} }
return results;
} }
} }
@ -177,37 +180,43 @@ public final class CompletionSuggestion extends Suggest.Suggestion<CompletionSug
final CompletionSuggestion leader = (CompletionSuggestion) toReduce.get(0); final CompletionSuggestion leader = (CompletionSuggestion) toReduce.get(0);
final Entry leaderEntry = leader.getEntries().get(0); final Entry leaderEntry = leader.getEntries().get(0);
final String name = leader.getName(); final String name = leader.getName();
int size = leader.getSize();
if (toReduce.size() == 1) { if (toReduce.size() == 1) {
return leader; return leader;
} else { } else {
// combine suggestion entries from participating shards on the coordinating node // combine suggestion entries from participating shards on the coordinating node
// the global top <code>size</code> entries are collected from the shard results // the global top <code>size</code> entries are collected from the shard results
// using a priority queue // using a priority queue
OptionPriorityQueue priorityQueue = new OptionPriorityQueue(leader.getSize(), COMPARATOR); OptionPriorityQueue pq = new OptionPriorityQueue(toReduce.size());
// Dedup duplicate suggestions (based on the surface form) if skip duplicates is activated
final CharArraySet seenSurfaceForms = leader.skipDuplicates ? new CharArraySet(leader.getSize(), false) : null;
for (Suggest.Suggestion<Entry> suggestion : toReduce) { for (Suggest.Suggestion<Entry> suggestion : toReduce) {
assert suggestion.getName().equals(name) : "name should be identical across all suggestions"; assert suggestion.getName().equals(name) : "name should be identical across all suggestions";
for (Entry.Option option : ((CompletionSuggestion) suggestion).getOptions()) { Iterator<Entry.Option> it = ((CompletionSuggestion) suggestion).getOptions().iterator();
if (leader.skipDuplicates) { if (it.hasNext()) {
assert ((CompletionSuggestion) suggestion).skipDuplicates; pq.add(new ShardOptions(it));
String text = option.getText().string();
if (seenSurfaceForms.contains(text)) {
continue;
} }
seenSurfaceForms.add(text);
} }
if (option == priorityQueue.insertWithOverflow(option)) { // Dedup duplicate suggestions (based on the surface form) if skip duplicates is activated
// if the current option has overflown from pq, final CharArraySet seenSurfaceForms = leader.skipDuplicates ? new CharArraySet(leader.getSize(), false) : null;
// we can assume all of the successive options final Entry entry = new Entry(leaderEntry.getText(), leaderEntry.getOffset(), leaderEntry.getLength());
// from this shard result will be overflown as well final List<Entry.Option> options = entry.getOptions();
while (pq.size() > 0) {
ShardOptions top = pq.top();
Entry.Option current = top.current;
if (top.advanceToNextOption()) {
pq.updateTop();
} else {
//options exhausted for this shard
pq.pop();
}
if (leader.skipDuplicates == false ||
seenSurfaceForms.add(current.getText().toString())) {
options.add(current);
if (options.size() >= size) {
break; break;
} }
} }
} }
final CompletionSuggestion suggestion = new CompletionSuggestion(leader.getName(), leader.getSize(), leader.skipDuplicates); final CompletionSuggestion suggestion = new CompletionSuggestion(leader.getName(), leader.getSize(), leader.skipDuplicates);
final Entry entry = new Entry(leaderEntry.getText(), leaderEntry.getOffset(), leaderEntry.getLength());
Collections.addAll(entry.getOptions(), priorityQueue.get());
suggestion.addTerm(entry); suggestion.addTerm(entry);
return suggestion; return suggestion;
} }

View File

@ -18,14 +18,7 @@
*/ */
package org.elasticsearch.search.suggest.completion; package org.elasticsearch.search.suggest.completion;
import org.apache.lucene.analysis.CharArraySet;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.suggest.Lookup;
import org.apache.lucene.search.suggest.document.TopSuggestDocs;
import org.apache.lucene.search.suggest.document.TopSuggestDocsCollector; import org.apache.lucene.search.suggest.document.TopSuggestDocsCollector;
import org.apache.lucene.util.PriorityQueue;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
@ -36,69 +29,13 @@ import java.util.Map;
/** /**
* *
* Custom {@link TopSuggestDocsCollector} that returns top documents from the completion suggester. * Extension of the {@link TopSuggestDocsCollector} that returns top documents from the completion suggester.
* <p> *
* TODO: this should be refactored when https://issues.apache.org/jira/browse/LUCENE-8529 is fixed.
* Unlike the parent class, this collector uses the surface form to tie-break suggestions with identical
* scores.
* <p>
* This collector groups suggestions coming from the same document but matching different contexts * This collector groups suggestions coming from the same document but matching different contexts
* or surface form together. When different contexts or surface forms match the same suggestion form only * or surface form together. When different contexts or surface forms match the same suggestion form only
* the best one per document (sorted by weight) is kept. * the best one per document (sorted by weight) is kept.
* <p>
* This collector is also able to filter duplicate suggestion coming from different documents.
* In order to keep this feature fast, the de-duplication of suggestions with different contexts is done
* only on the top N*num_contexts (where N is the number of documents to return) suggestions per segment.
* This means that skip_duplicates will visit at most N*num_contexts suggestions per segment to find unique suggestions
* that match the input. If more than N*num_contexts suggestions are duplicated with different contexts this collector
* will not be able to return more than one suggestion even when N is greater than 1.
**/ **/
class TopSuggestGroupDocsCollector extends TopSuggestDocsCollector { class TopSuggestGroupDocsCollector extends TopSuggestDocsCollector {
private final class SuggestScoreDocPriorityQueue extends PriorityQueue<TopSuggestDocs.SuggestScoreDoc> {
/**
* Creates a new priority queue of the specified size.
*/
private SuggestScoreDocPriorityQueue(int size) {
super(size);
}
@Override
protected boolean lessThan(TopSuggestDocs.SuggestScoreDoc a, TopSuggestDocs.SuggestScoreDoc b) {
if (a.score == b.score) {
// tie break by completion key
int cmp = Lookup.CHARSEQUENCE_COMPARATOR.compare(a.key, b.key);
// prefer smaller doc id, in case of a tie
return cmp != 0 ? cmp > 0 : a.doc > b.doc;
}
return a.score < b.score;
}
/**
* Returns the top N results in descending order.
*/
public TopSuggestDocs.SuggestScoreDoc[] getResults() {
int size = size();
TopSuggestDocs.SuggestScoreDoc[] res = new TopSuggestDocs.SuggestScoreDoc[size];
for (int i = size - 1; i >= 0; i--) {
res[i] = pop();
}
return res;
}
}
private final SuggestScoreDocPriorityQueue priorityQueue;
private final int num;
/** Only set if we are deduplicating hits: holds all per-segment hits until the end, when we dedup them */
private final List<TopSuggestDocs.SuggestScoreDoc> pendingResults;
/** Only set if we are deduplicating hits: holds all surface forms seen so far in the current segment */
final CharArraySet seenSurfaceForms;
/** Document base offset for the current Leaf */
protected int docBase;
private Map<Integer, List<CharSequence>> docContexts = new HashMap<>(); private Map<Integer, List<CharSequence>> docContexts = new HashMap<>();
/** /**
@ -108,125 +45,24 @@ class TopSuggestGroupDocsCollector extends TopSuggestDocsCollector {
* with corresponding document and weight * with corresponding document and weight
*/ */
TopSuggestGroupDocsCollector(int num, boolean skipDuplicates) { TopSuggestGroupDocsCollector(int num, boolean skipDuplicates) {
super(1, skipDuplicates); super(num, skipDuplicates);
if (num <= 0) {
throw new IllegalArgumentException("'num' must be > 0");
}
this.num = num;
this.priorityQueue = new SuggestScoreDocPriorityQueue(num);
if (skipDuplicates) {
seenSurfaceForms = new CharArraySet(num, false);
pendingResults = new ArrayList<>();
} else {
seenSurfaceForms = null;
pendingResults = null;
}
} }
/** /**
* Returns the contexts associated with the provided <code>doc</code>. * Returns the contexts associated with the provided <code>doc</code>.
*/ */
public List<CharSequence> getContexts(int doc) { List<CharSequence> getContexts(int doc) {
return docContexts.getOrDefault(doc, Collections.emptyList()); return docContexts.getOrDefault(doc, Collections.emptyList());
} }
@Override
protected boolean doSkipDuplicates() {
return seenSurfaceForms != null;
}
@Override
public int getCountToCollect() {
return num;
}
@Override
protected void doSetNextReader(LeafReaderContext context) throws IOException {
docBase = context.docBase;
if (seenSurfaceForms != null) {
seenSurfaceForms.clear();
// NOTE: this also clears the priorityQueue:
for (TopSuggestDocs.SuggestScoreDoc hit : priorityQueue.getResults()) {
pendingResults.add(hit);
}
}
}
@Override @Override
public void collect(int docID, CharSequence key, CharSequence context, float score) throws IOException { public void collect(int docID, CharSequence key, CharSequence context, float score) throws IOException {
int globalDoc = docID + docBase; int globalDoc = docID + docBase;
boolean isDuplicate = docContexts.containsKey(globalDoc); boolean isNewDoc = docContexts.containsKey(globalDoc) == false;
List<CharSequence> contexts = docContexts.computeIfAbsent(globalDoc, k -> new ArrayList<>()); List<CharSequence> contexts = docContexts.computeIfAbsent(globalDoc, k -> new ArrayList<>());
if (context != null) {
contexts.add(context); contexts.add(context);
} if (isNewDoc) {
if (isDuplicate) { super.collect(docID, key, context, score);
return;
}
TopSuggestDocs.SuggestScoreDoc current = new TopSuggestDocs.SuggestScoreDoc(globalDoc, key, context, score);
if (current == priorityQueue.insertWithOverflow(current)) {
// if the current SuggestScoreDoc has overflown from pq,
// we can assume all of the successive collections from
// this leaf will be overflown as well
// TODO: reuse the overflow instance?
throw new CollectionTerminatedException();
} }
} }
@Override
public TopSuggestDocs get() throws IOException {
TopSuggestDocs.SuggestScoreDoc[] suggestScoreDocs;
if (seenSurfaceForms != null) {
// NOTE: this also clears the priorityQueue:
for (TopSuggestDocs.SuggestScoreDoc hit : priorityQueue.getResults()) {
pendingResults.add(hit);
}
// Deduplicate all hits: we already dedup'd efficiently within each segment by
// truncating the FST top paths search, but across segments there may still be dups:
seenSurfaceForms.clear();
// TODO: we could use a priority queue here to make cost O(N * log(num)) instead of O(N * log(N)), where N = O(num *
// numSegments), but typically numSegments is smallish and num is smallish so this won't matter much in practice:
Collections.sort(pendingResults,
(a, b) -> {
// sort by higher score
int cmp = Float.compare(b.score, a.score);
if (cmp == 0) {
// tie break by completion key
cmp = Lookup.CHARSEQUENCE_COMPARATOR.compare(a.key, b.key);
if (cmp == 0) {
// prefer smaller doc id, in case of a tie
cmp = Integer.compare(a.doc, b.doc);
}
}
return cmp;
});
List<TopSuggestDocs.SuggestScoreDoc> hits = new ArrayList<>();
for (TopSuggestDocs.SuggestScoreDoc hit : pendingResults) {
if (seenSurfaceForms.contains(hit.key) == false) {
seenSurfaceForms.add(hit.key);
hits.add(hit);
if (hits.size() == num) {
break;
}
}
}
suggestScoreDocs = hits.toArray(new TopSuggestDocs.SuggestScoreDoc[0]);
} else {
suggestScoreDocs = priorityQueue.getResults();
}
if (suggestScoreDocs.length > 0) {
return new TopSuggestDocs(new TotalHits(suggestScoreDocs.length, TotalHits.Relation.EQUAL_TO), suggestScoreDocs);
} else {
return TopSuggestDocs.EMPTY;
}
}
} }

View File

@ -54,6 +54,7 @@ import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Comparator;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Locale; import java.util.Locale;
@ -890,29 +891,40 @@ public class CompletionSuggestSearchIT extends ESIntegTestCase {
int numDocs = randomIntBetween(10, 100); int numDocs = randomIntBetween(10, 100);
int numUnique = randomIntBetween(1, numDocs); int numUnique = randomIntBetween(1, numDocs);
List<IndexRequestBuilder> indexRequestBuilders = new ArrayList<>(); List<IndexRequestBuilder> indexRequestBuilders = new ArrayList<>();
int[] weights = new int[numUnique];
Integer[] termIds = new Integer[numUnique];
for (int i = 1; i <= numDocs; i++) { for (int i = 1; i <= numDocs; i++) {
int id = i % numUnique; int id = i % numUnique;
indexRequestBuilders.add(client().prepareIndex(INDEX, TYPE, "" + i) termIds[id] = id;
int weight = randomIntBetween(0, 100);
weights[id] = Math.max(weight, weights[id]);
String suggestion = "suggestion-" + String.format(Locale.ENGLISH, "%03d" , id);
indexRequestBuilders.add(client().prepareIndex(INDEX, TYPE)
.setSource(jsonBuilder() .setSource(jsonBuilder()
.startObject() .startObject()
.startObject(FIELD) .startObject(FIELD)
.field("input", "suggestion" + id) .field("input", suggestion)
.field("weight", id) .field("weight", weight)
.endObject() .endObject()
.endObject() .endObject()
)); ));
} }
String[] expected = new String[numUnique];
int sugg = numUnique - 1;
for (int i = 0; i < numUnique; i++) {
expected[i] = "suggestion" + sugg--;
}
indexRandom(true, indexRequestBuilders); indexRandom(true, indexRequestBuilders);
CompletionSuggestionBuilder completionSuggestionBuilder =
SuggestBuilders.completionSuggestion(FIELD).prefix("sugg").skipDuplicates(true).size(numUnique); Arrays.sort(termIds,
Comparator.comparingInt(o -> weights[(int) o]).reversed().thenComparingInt(a -> (int) a));
String[] expected = new String[numUnique];
for (int i = 0; i < termIds.length; i++) {
expected[i] = "suggestion-" + String.format(Locale.ENGLISH, "%03d" , termIds[i]);
}
CompletionSuggestionBuilder completionSuggestionBuilder = SuggestBuilders.completionSuggestion(FIELD)
.prefix("sugg")
.skipDuplicates(true)
.size(numUnique);
SearchResponse searchResponse = client().prepareSearch(INDEX) SearchResponse searchResponse = client().prepareSearch(INDEX)
.suggest(new SuggestBuilder().addSuggestion("suggestions", completionSuggestionBuilder)).get(); .suggest(new SuggestBuilder().addSuggestion("suggestions", completionSuggestionBuilder))
.get();
assertSuggestions(searchResponse, true, "suggestions", expected); assertSuggestions(searchResponse, true, "suggestions", expected);
} }

View File

@ -26,7 +26,9 @@ import org.elasticsearch.test.ESTestCase;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.stream.Collectors;
import static org.elasticsearch.search.suggest.Suggest.COMPARATOR;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.hamcrest.Matchers.lessThanOrEqualTo;
@ -61,4 +63,41 @@ public class CompletionSuggestionTests extends ESTestCase {
count++; count++;
} }
} }
public void testToReduceWithDuplicates() {
List<Suggest.Suggestion<CompletionSuggestion.Entry>> shardSuggestions = new ArrayList<>();
int nShards = randomIntBetween(2, 10);
String name = randomAlphaOfLength(10);
int size = randomIntBetween(10, 100);
int totalResults = size * nShards;
int numSurfaceForms = randomIntBetween(1, size);
String[] surfaceForms = new String[numSurfaceForms];
for (int i = 0; i < numSurfaceForms; i++) {
surfaceForms[i] = randomAlphaOfLength(20);
}
List<CompletionSuggestion.Entry.Option> options = new ArrayList<>();
for (int i = 0; i < nShards; i++) {
CompletionSuggestion suggestion = new CompletionSuggestion(name, size, true);
CompletionSuggestion.Entry entry = new CompletionSuggestion.Entry(new Text(""), 0, 0);
suggestion.addTerm(entry);
int maxScore = randomIntBetween(totalResults, totalResults*2);
for (int j = 0; j < size; j++) {
String surfaceForm = randomFrom(surfaceForms);
CompletionSuggestion.Entry.Option newOption =
new CompletionSuggestion.Entry.Option(j, new Text(surfaceForm), maxScore - j, Collections.emptyMap());
entry.addOption(newOption);
options.add(newOption);
}
shardSuggestions.add(suggestion);
}
List<CompletionSuggestion.Entry.Option> expected = options.stream()
.sorted(COMPARATOR)
.distinct()
.limit(size)
.collect(Collectors.toList());
CompletionSuggestion reducedSuggestion = CompletionSuggestion.reduceTo(shardSuggestions);
assertNotNull(reducedSuggestion);
assertThat(reducedSuggestion.getOptions().size(), lessThanOrEqualTo(size));
assertEquals(expected, reducedSuggestion.getOptions());
}
} }