diff --git a/server/src/main/java/org/elasticsearch/search/suggest/completion/CompletionSuggestion.java b/server/src/main/java/org/elasticsearch/search/suggest/completion/CompletionSuggestion.java
index 0fb7e4b3f43..34c191df3b3 100644
--- a/server/src/main/java/org/elasticsearch/search/suggest/completion/CompletionSuggestion.java
+++ b/server/src/main/java/org/elasticsearch/search/suggest/completion/CompletionSuggestion.java
@@ -20,7 +20,7 @@ package org.elasticsearch.search.suggest.completion;
 
 import org.apache.lucene.analysis.CharArraySet;
 import org.apache.lucene.search.ScoreDoc;
-import org.apache.lucene.search.suggest.Lookup;
+import org.apache.lucene.util.PriorityQueue;
 import org.elasticsearch.Version;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.io.stream.StreamInput;
@@ -36,9 +36,9 @@ import org.elasticsearch.search.suggest.Suggest.Suggestion;
 
 import java.io.IOException;
 import java.util.Collections;
-import java.util.Comparator;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.Iterator;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
@@ -138,31 +138,34 @@ public final class CompletionSuggestion extends Suggest.Suggestion<CompletionSu
 
-    private static final class OptionPriorityQueue extends org.apache.lucene.util.PriorityQueue<Entry.Option> {
-
-        private final Comparator<Suggest.Suggestion.Entry.Option> comparator;
-
-        OptionPriorityQueue(int maxSize, Comparator<Suggest.Suggestion.Entry.Option> comparator) {
+    private static final class OptionPriorityQueue extends PriorityQueue<ShardOptions> {
+        OptionPriorityQueue(int maxSize) {
             super(maxSize);
-            this.comparator = comparator;
         }
 
         @Override
-        protected boolean lessThan(Entry.Option a, Entry.Option b) {
-            int cmp = comparator.compare(a, b);
-            if (cmp != 0) {
-                return cmp > 0;
-            }
-            return Lookup.CHARSEQUENCE_COMPARATOR.compare(a.getText().string(), b.getText().string()) > 0;
+        protected boolean lessThan(ShardOptions a, ShardOptions b) {
+            return COMPARATOR.compare(a.current, b.current) < 0;
+        }
+    }
+
+    private static class ShardOptions {
+        final Iterator<Entry.Option> optionsIterator;
+        Entry.Option current;
+
+        private ShardOptions(Iterator<Entry.Option> optionsIterator) {
+            assert optionsIterator.hasNext();
+            this.optionsIterator = optionsIterator;
+            this.current = optionsIterator.next();
        }
 
-        Entry.Option[] get() {
-            int size = size();
-            Entry.Option[] results = new Entry.Option[size];
-            for (int i = size - 1; i >= 0; i--) {
-                results[i] = pop();
-            }
-            return results;
+        boolean advanceToNextOption() {
+            if (optionsIterator.hasNext()) {
+                current = optionsIterator.next();
+                return true;
+            } else {
+                return false;
+            }
         }
     }
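A note on the pattern in the hunk above (commentary, not part of the patch): `OptionPriorityQueue` now holds one `ShardOptions` cursor per shard, and `lessThan` compares the cursors' current heads with `COMPARATOR`, which turns the reduce into a k-way merge of already-sorted shard results. The sketch below shows the same merge pattern over plain sorted integer lists; the `Head`/`merge` names and the data are illustrative only.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

import org.apache.lucene.util.PriorityQueue;

public class KWayMergeSketch {

    /** Mirrors ShardOptions: one iterator plus its current head element. */
    private static final class Head {
        final Iterator<Integer> it;
        int current;

        Head(Iterator<Integer> it) {
            assert it.hasNext();
            this.it = it;
            this.current = it.next();
        }

        boolean advance() {
            if (it.hasNext()) {
                current = it.next();
                return true;
            }
            return false;
        }
    }

    static List<Integer> merge(List<List<Integer>> sortedInputs, int limit) {
        // the queue is sized by the number of inputs (shards), not by the number of elements
        PriorityQueue<Head> pq = new PriorityQueue<Head>(sortedInputs.size()) {
            @Override
            protected boolean lessThan(Head a, Head b) {
                return a.current < b.current; // smallest head on top
            }
        };
        for (List<Integer> input : sortedInputs) {
            Iterator<Integer> it = input.iterator();
            if (it.hasNext()) {
                pq.add(new Head(it));
            }
        }
        List<Integer> merged = new ArrayList<>();
        while (pq.size() > 0 && merged.size() < limit) {
            Head top = pq.top();
            merged.add(top.current);
            if (top.advance()) {
                pq.updateTop(); // the head changed, restore heap order
            } else {
                pq.pop();       // this input is exhausted
            }
        }
        return merged;
    }

    public static void main(String[] args) {
        List<List<Integer>> shards = Arrays.asList(
            Arrays.asList(1, 4, 7),
            Arrays.asList(2, 3, 9),
            Arrays.asList(5, 6, 8));
        System.out.println(merge(shards, 5)); // [1, 2, 3, 4, 5]
    }
}
```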
@@ -177,37 +180,43 @@ public final class CompletionSuggestion extends Suggest.Suggestion<CompletionSu
                 // combine suggestion entries from participating shards on the coordinating node
                 // the global top <code>size</code> entries are collected from the shard results
                 // using a priority queue
-                OptionPriorityQueue priorityQueue = new OptionPriorityQueue(leader.getSize(), COMPARATOR);
-                // Dedup duplicate suggestions (based on the surface form) if skip duplicates is activated
-                final CharArraySet seenSurfaceForms = leader.skipDuplicates ?
-                    new CharArraySet(leader.getSize(), false) : null;
+                OptionPriorityQueue pq = new OptionPriorityQueue(toReduce.size());
                 for (Suggest.Suggestion<Entry> suggestion : toReduce) {
                     assert suggestion.getName().equals(name) : "name should be identical across all suggestions";
-                    for (Entry.Option option : ((CompletionSuggestion) suggestion).getOptions()) {
-                        if (leader.skipDuplicates) {
-                            assert ((CompletionSuggestion) suggestion).skipDuplicates;
-                            String text = option.getText().string();
-                            if (seenSurfaceForms.contains(text)) {
-                                continue;
-                            }
-                            seenSurfaceForms.add(text);
-                        }
-                        if (option == priorityQueue.insertWithOverflow(option)) {
-                            // if the current option has overflown from pq,
-                            // we can assume all of the successive options
-                            // from this shard result will be overflown as well
+                    Iterator<Entry.Option> it = ((CompletionSuggestion) suggestion).getOptions().iterator();
+                    if (it.hasNext()) {
+                        pq.add(new ShardOptions(it));
+                    }
+                }
+                // Dedup duplicate suggestions (based on the surface form) if skip duplicates is activated
+                final CharArraySet seenSurfaceForms = leader.skipDuplicates ? new CharArraySet(leader.getSize(), false) : null;
+                final Entry entry = new Entry(leaderEntry.getText(), leaderEntry.getOffset(), leaderEntry.getLength());
+                final List<Entry.Option> options = entry.getOptions();
+                while (pq.size() > 0) {
+                    ShardOptions top = pq.top();
+                    Entry.Option current = top.current;
+                    if (top.advanceToNextOption()) {
+                        pq.updateTop();
+                    } else {
+                        // options exhausted for this shard
+                        pq.pop();
+                    }
+                    if (leader.skipDuplicates == false ||
+                            seenSurfaceForms.add(current.getText().toString())) {
+                        options.add(current);
+                        if (options.size() >= size) {
                             break;
                         }
                     }
                 }
                 final CompletionSuggestion suggestion = new CompletionSuggestion(leader.getName(), leader.getSize(), leader.skipDuplicates);
-                final Entry entry = new Entry(leaderEntry.getText(), leaderEntry.getOffset(), leaderEntry.getLength());
-                Collections.addAll(entry.getOptions(), priorityQueue.get());
                 suggestion.addTerm(entry);
                 return suggestion;
             }
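The reduce above also leans on the fact that `CharArraySet.add` returns `false` when the key is already present, so duplicate surface forms can be dropped inline while the merged stream keeps advancing. A small standalone sketch of that idiom (the input strings are made up):

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.lucene.analysis.CharArraySet;

public class DedupSketch {
    public static void main(String[] args) {
        List<String> merged = Arrays.asList("foo", "bar", "foo", "baz", "bar");
        CharArraySet seenSurfaceForms = new CharArraySet(merged.size(), false); // case-sensitive
        List<String> unique = new ArrayList<>();
        for (String surfaceForm : merged) {
            if (seenSurfaceForms.add(surfaceForm)) { // false when already seen -> skip duplicate
                unique.add(surfaceForm);
            }
        }
        System.out.println(unique); // [foo, bar, baz]
    }
}
```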
diff --git a/server/src/main/java/org/elasticsearch/search/suggest/completion/TopSuggestGroupDocsCollector.java b/server/src/main/java/org/elasticsearch/search/suggest/completion/TopSuggestGroupDocsCollector.java
index 879b8114b2a..3dfb38bef9d 100644
--- a/server/src/main/java/org/elasticsearch/search/suggest/completion/TopSuggestGroupDocsCollector.java
+++ b/server/src/main/java/org/elasticsearch/search/suggest/completion/TopSuggestGroupDocsCollector.java
@@ -18,14 +18,7 @@
  */
 package org.elasticsearch.search.suggest.completion;
 
-import org.apache.lucene.analysis.CharArraySet;
-import org.apache.lucene.index.LeafReaderContext;
-import org.apache.lucene.search.CollectionTerminatedException;
-import org.apache.lucene.search.TotalHits;
-import org.apache.lucene.search.suggest.Lookup;
-import org.apache.lucene.search.suggest.document.TopSuggestDocs;
 import org.apache.lucene.search.suggest.document.TopSuggestDocsCollector;
-import org.apache.lucene.util.PriorityQueue;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -36,69 +29,13 @@ import java.util.Map;
 
 /**
  *
- * Custom {@link TopSuggestDocsCollector} that returns top documents from the completion suggester.
- * <p>
- * TODO: this should be refactored when https://issues.apache.org/jira/browse/LUCENE-8529 is fixed.
- * Unlike the parent class, this collector uses the surface form to tie-break suggestions with identical
- * scores.
- * <p>
+ * Extension of the {@link TopSuggestDocsCollector} that returns top documents from the completion suggester.
+ * <p>
  * This collector groups suggestions coming from the same document but matching different contexts
  * or surface form together. When different contexts or surface forms match the same suggestion form only
  * the best one per document (sorted by weight) is kept.
- * <p>
- * This collector is also able to filter duplicate suggestion coming from different documents.
- * In order to keep this feature fast, the de-duplication of suggestions with different contexts is done
- * only on the top N*num_contexts (where N is the number of documents to return) suggestions per segment.
- * This means that skip_duplicates will visit at most N*num_contexts suggestions per segment to find unique suggestions
- * that match the input. If more than N*num_contexts suggestions are duplicated with different contexts this collector
- * will not be able to return more than one suggestion even when N is greater than 1.
  **/
 class TopSuggestGroupDocsCollector extends TopSuggestDocsCollector {
-    private final class SuggestScoreDocPriorityQueue extends PriorityQueue<TopSuggestDocs.SuggestScoreDoc> {
-        /**
-         * Creates a new priority queue of the specified size.
-         */
-        private SuggestScoreDocPriorityQueue(int size) {
-            super(size);
-        }
-
-        @Override
-        protected boolean lessThan(TopSuggestDocs.SuggestScoreDoc a, TopSuggestDocs.SuggestScoreDoc b) {
-            if (a.score == b.score) {
-                // tie break by completion key
-                int cmp = Lookup.CHARSEQUENCE_COMPARATOR.compare(a.key, b.key);
-                // prefer smaller doc id, in case of a tie
-                return cmp != 0 ? cmp > 0 : a.doc > b.doc;
-            }
-            return a.score < b.score;
-        }
-
-        /**
-         * Returns the top N results in descending order.
-         */
-        public TopSuggestDocs.SuggestScoreDoc[] getResults() {
-            int size = size();
-            TopSuggestDocs.SuggestScoreDoc[] res = new TopSuggestDocs.SuggestScoreDoc[size];
-            for (int i = size - 1; i >= 0; i--) {
-                res[i] = pop();
-            }
-            return res;
-        }
-    }
-
-
-    private final SuggestScoreDocPriorityQueue priorityQueue;
-    private final int num;
-
-    /** Only set if we are deduplicating hits: holds all per-segment hits until the end, when we dedup them */
-    private final List<TopSuggestDocs.SuggestScoreDoc> pendingResults;
-
-    /** Only set if we are deduplicating hits: holds all surface forms seen so far in the current segment */
-    final CharArraySet seenSurfaceForms;
-
-    /** Document base offset for the current Leaf */
-    protected int docBase;
-
     private Map<Integer, List<CharSequence>> docContexts = new HashMap<>();
 
     /**
@@ -108,125 +45,24 @@ class TopSuggestGroupDocsCollector extends TopSuggestDocsCollector {
      * with corresponding document and weight
      */
     TopSuggestGroupDocsCollector(int num, boolean skipDuplicates) {
-        super(1, skipDuplicates);
-        if (num <= 0) {
-            throw new IllegalArgumentException("'num' must be > 0");
-        }
-        this.num = num;
-        this.priorityQueue = new SuggestScoreDocPriorityQueue(num);
-        if (skipDuplicates) {
-            seenSurfaceForms = new CharArraySet(num, false);
-            pendingResults = new ArrayList<>();
-        } else {
-            seenSurfaceForms = null;
-            pendingResults = null;
-        }
+        super(num, skipDuplicates);
     }
 
     /**
     * Returns the contexts associated with the provided doc.
     */
-    public List<CharSequence> getContexts(int doc) {
+    List<CharSequence> getContexts(int doc) {
         return docContexts.getOrDefault(doc, Collections.emptyList());
     }
 
-    @Override
-    protected boolean doSkipDuplicates() {
-        return seenSurfaceForms != null;
-    }
-
-    @Override
-    public int getCountToCollect() {
-        return num;
-    }
-
-    @Override
-    protected void doSetNextReader(LeafReaderContext context) throws IOException {
-        docBase = context.docBase;
-        if (seenSurfaceForms != null) {
-            seenSurfaceForms.clear();
-            // NOTE: this also clears the priorityQueue:
-            for (TopSuggestDocs.SuggestScoreDoc hit : priorityQueue.getResults()) {
-                pendingResults.add(hit);
-            }
-        }
-    }
-
     @Override
     public void collect(int docID, CharSequence key, CharSequence context, float score) throws IOException {
         int globalDoc = docID + docBase;
-        boolean isDuplicate = docContexts.containsKey(globalDoc);
+        boolean isNewDoc = docContexts.containsKey(globalDoc) == false;
         List<CharSequence> contexts = docContexts.computeIfAbsent(globalDoc, k -> new ArrayList<>());
-        if (context != null) {
-            contexts.add(context);
-        }
-        if (isDuplicate) {
-            return;
-        }
-        TopSuggestDocs.SuggestScoreDoc current = new TopSuggestDocs.SuggestScoreDoc(globalDoc, key, context, score);
-        if (current == priorityQueue.insertWithOverflow(current)) {
-            // if the current SuggestScoreDoc has overflown from pq,
-            // we can assume all of the successive collections from
-            // this leaf will be overflown as well
-            // TODO: reuse the overflow instance?
-            throw new CollectionTerminatedException();
+        contexts.add(context);
+        if (isNewDoc) {
+            super.collect(docID, key, context, score);
         }
     }
-
-    @Override
-    public TopSuggestDocs get() throws IOException {
-
-        TopSuggestDocs.SuggestScoreDoc[] suggestScoreDocs;
-
-        if (seenSurfaceForms != null) {
-            // NOTE: this also clears the priorityQueue:
-            for (TopSuggestDocs.SuggestScoreDoc hit : priorityQueue.getResults()) {
-                pendingResults.add(hit);
-            }
-
-            // Deduplicate all hits: we already dedup'd efficiently within each segment by
-            // truncating the FST top paths search, but across segments there may still be dups:
-            seenSurfaceForms.clear();
-
-            // TODO: we could use a priority queue here to make cost O(N * log(num)) instead of O(N * log(N)), where N = O(num *
-            // numSegments), but typically numSegments is smallish and num is smallish so this won't matter much in practice:
-
-            Collections.sort(pendingResults,
-                (a, b) -> {
-                    // sort by higher score
-                    int cmp = Float.compare(b.score, a.score);
-                    if (cmp == 0) {
-                        // tie break by completion key
-                        cmp = Lookup.CHARSEQUENCE_COMPARATOR.compare(a.key, b.key);
-                        if (cmp == 0) {
-                            // prefer smaller doc id, in case of a tie
-                            cmp = Integer.compare(a.doc, b.doc);
-                        }
-                    }
-                    return cmp;
-                });
-
-            List<TopSuggestDocs.SuggestScoreDoc> hits = new ArrayList<>();
-
-            for (TopSuggestDocs.SuggestScoreDoc hit : pendingResults) {
-                if (seenSurfaceForms.contains(hit.key) == false) {
-                    seenSurfaceForms.add(hit.key);
-                    hits.add(hit);
-                    if (hits.size() == num) {
-                        break;
-                    }
-                }
-            }
-            suggestScoreDocs = hits.toArray(new TopSuggestDocs.SuggestScoreDoc[0]);
-        } else {
-            suggestScoreDocs = priorityQueue.getResults();
-        }
-
-        if (suggestScoreDocs.length > 0) {
-            return new TopSuggestDocs(new TotalHits(suggestScoreDocs.length, TotalHits.Relation.EQUAL_TO), suggestScoreDocs);
-        } else {
-            return TopSuggestDocs.EMPTY;
-        }
-    }
-
 }
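With LUCENE-8529 fixed upstream, the collector above no longer re-implements top-N selection or skip_duplicates; it only keeps the per-document grouping of contexts and forwards the first (best) hit per document to `super.collect(...)`. The standalone sketch below isolates that grouping idea; `GroupByDocSketch` and `Delegate` are made-up names standing in for the real collector and its superclass.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class GroupByDocSketch {

    /** Stands in for TopSuggestDocsCollector.collect(...) in the real class. */
    interface Delegate {
        void collect(int doc, CharSequence key, CharSequence context, float score);
    }

    private final Map<Integer, List<CharSequence>> docContexts = new HashMap<>();
    private final Delegate delegate;

    GroupByDocSketch(Delegate delegate) {
        this.delegate = delegate;
    }

    void collect(int doc, CharSequence key, CharSequence context, float score) {
        boolean isNewDoc = docContexts.containsKey(doc) == false;
        // always record the context, even for hits on an already-seen document
        docContexts.computeIfAbsent(doc, k -> new ArrayList<>()).add(context);
        if (isNewDoc) {
            // suggestions are visited best-first, so the first hit per document is the one to keep
            delegate.collect(doc, key, context, score);
        }
    }

    List<CharSequence> getContexts(int doc) {
        return docContexts.getOrDefault(doc, Collections.emptyList());
    }
}
```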
diff --git a/server/src/test/java/org/elasticsearch/search/suggest/CompletionSuggestSearchIT.java b/server/src/test/java/org/elasticsearch/search/suggest/CompletionSuggestSearchIT.java
index d4a0edf914e..f8a333cda69 100644
--- a/server/src/test/java/org/elasticsearch/search/suggest/CompletionSuggestSearchIT.java
+++ b/server/src/test/java/org/elasticsearch/search/suggest/CompletionSuggestSearchIT.java
@@ -54,6 +54,7 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
+import java.util.Comparator;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Locale;
@@ -890,29 +891,40 @@ public class CompletionSuggestSearchIT extends ESIntegTestCase {
         int numDocs = randomIntBetween(10, 100);
         int numUnique = randomIntBetween(1, numDocs);
         List<IndexRequestBuilder> indexRequestBuilders = new ArrayList<>();
+        int[] weights = new int[numUnique];
+        Integer[] termIds = new Integer[numUnique];
         for (int i = 1; i <= numDocs; i++) {
             int id = i % numUnique;
-            indexRequestBuilders.add(client().prepareIndex(INDEX, TYPE, "" + i)
+            termIds[id] = id;
+            int weight = randomIntBetween(0, 100);
+            weights[id] = Math.max(weight, weights[id]);
+            String suggestion = "suggestion-" + String.format(Locale.ENGLISH, "%03d", id);
+            indexRequestBuilders.add(client().prepareIndex(INDEX, TYPE)
                 .setSource(jsonBuilder()
                     .startObject()
                         .startObject(FIELD)
-                            .field("input", "suggestion" + id)
-                            .field("weight", id)
+                            .field("input", suggestion)
+                            .field("weight", weight)
                         .endObject()
                     .endObject()
                 ));
         }
-        String[] expected = new String[numUnique];
-        int sugg = numUnique - 1;
-        for (int i = 0; i < numUnique; i++) {
-            expected[i] = "suggestion" + sugg--;
-        }
         indexRandom(true, indexRequestBuilders);
-        CompletionSuggestionBuilder completionSuggestionBuilder =
-            SuggestBuilders.completionSuggestion(FIELD).prefix("sugg").skipDuplicates(true).size(numUnique);
+
+        Arrays.sort(termIds,
+            Comparator.comparingInt(o -> weights[(int) o]).reversed().thenComparingInt(a -> (int) a));
+        String[] expected = new String[numUnique];
+        for (int i = 0; i < termIds.length; i++) {
+            expected[i] = "suggestion-" + String.format(Locale.ENGLISH, "%03d", termIds[i]);
+        }
+        CompletionSuggestionBuilder completionSuggestionBuilder = SuggestBuilders.completionSuggestion(FIELD)
+            .prefix("sugg")
+            .skipDuplicates(true)
+            .size(numUnique);
         SearchResponse searchResponse = client().prepareSearch(INDEX)
-            .suggest(new SuggestBuilder().addSuggestion("suggestions", completionSuggestionBuilder)).get();
+            .suggest(new SuggestBuilder().addSuggestion("suggestions", completionSuggestionBuilder))
+            .get();
         assertSuggestions(searchResponse, true, "suggestions", expected);
     }
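The expected ordering in the test above ranks each term id by the best weight indexed for it (highest first) and breaks ties on the id itself. A tiny fixed-data illustration of that comparator (the values are made up):

```java
import java.util.Arrays;
import java.util.Comparator;

public class ExpectedOrderSketch {
    public static void main(String[] args) {
        int[] weights = {5, 9, 9, 1};      // weights[id] = best weight indexed for that id
        Integer[] termIds = {0, 1, 2, 3};
        Arrays.sort(termIds,
            Comparator.comparingInt((Integer o) -> weights[o]).reversed()
                .thenComparingInt(o -> o));
        System.out.println(Arrays.toString(termIds)); // [1, 2, 0, 3]
    }
}
```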
diff --git a/server/src/test/java/org/elasticsearch/search/suggest/completion/CompletionSuggestionTests.java b/server/src/test/java/org/elasticsearch/search/suggest/completion/CompletionSuggestionTests.java
index bfa9bcb8093..8f7b731f7ae 100644
--- a/server/src/test/java/org/elasticsearch/search/suggest/completion/CompletionSuggestionTests.java
+++ b/server/src/test/java/org/elasticsearch/search/suggest/completion/CompletionSuggestionTests.java
@@ -26,7 +26,9 @@ import org.elasticsearch.test.ESTestCase;
 
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
+import java.util.stream.Collectors;
 
+import static org.elasticsearch.search.suggest.Suggest.COMPARATOR;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.lessThanOrEqualTo;
@@ -61,4 +63,41 @@ public class CompletionSuggestionTests extends ESTestCase {
             count++;
         }
     }
+
+    public void testToReduceWithDuplicates() {
+        List<Suggest.Suggestion<CompletionSuggestion.Entry>> shardSuggestions = new ArrayList<>();
+        int nShards = randomIntBetween(2, 10);
+        String name = randomAlphaOfLength(10);
+        int size = randomIntBetween(10, 100);
+        int totalResults = size * nShards;
+        int numSurfaceForms = randomIntBetween(1, size);
+        String[] surfaceForms = new String[numSurfaceForms];
+        for (int i = 0; i < numSurfaceForms; i++) {
+            surfaceForms[i] = randomAlphaOfLength(20);
+        }
+        List<CompletionSuggestion.Entry.Option> options = new ArrayList<>();
+        for (int i = 0; i < nShards; i++) {
+            CompletionSuggestion suggestion = new CompletionSuggestion(name, size, true);
+            CompletionSuggestion.Entry entry = new CompletionSuggestion.Entry(new Text(""), 0, 0);
+            suggestion.addTerm(entry);
+            int maxScore = randomIntBetween(totalResults, totalResults*2);
+            for (int j = 0; j < size; j++) {
+                String surfaceForm = randomFrom(surfaceForms);
+                CompletionSuggestion.Entry.Option newOption =
+                    new CompletionSuggestion.Entry.Option(j, new Text(surfaceForm), maxScore - j, Collections.emptyMap());
+                entry.addOption(newOption);
+                options.add(newOption);
+            }
+            shardSuggestions.add(suggestion);
+        }
+        List<CompletionSuggestion.Entry.Option> expected = options.stream()
+            .sorted(COMPARATOR)
+            .distinct()
+            .limit(size)
+            .collect(Collectors.toList());
+        CompletionSuggestion reducedSuggestion = CompletionSuggestion.reduceTo(shardSuggestions);
+        assertNotNull(reducedSuggestion);
+        assertThat(reducedSuggestion.getOptions().size(), lessThanOrEqualTo(size));
+        assertEquals(expected, reducedSuggestion.getOptions());
+    }
 }
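The unit test derives its expected list with a sort/dedup/truncate stream pipeline over all shard options. The snippet below shows the same pipeline shape on plain strings for illustration; the real test sorts with `Suggest.COMPARATOR` and deduplicates through `Option` equality rather than natural string order.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

public class ExpectedPipelineSketch {
    public static void main(String[] args) {
        List<String> merged = Arrays.asList("pear", "apple", "apple", "kiwi", "banana", "kiwi");
        int size = 3;
        List<String> expected = merged.stream()
            .sorted()   // the test sorts with Suggest.COMPARATOR instead of natural order
            .distinct() // the test relies on Option equality for de-duplication
            .limit(size)
            .collect(Collectors.toList());
        System.out.println(expected); // [apple, banana, kiwi]
    }
}
```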