LUCENE-10539: Return a stream of completions from FSTCompletion. (#844)

This commit is contained in:
Dawid Weiss 2022-04-29 21:35:35 +02:00 committed by GitHub
parent 75aadb9589
commit 05de9085ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 204 additions and 130 deletions

View File

@ -76,6 +76,8 @@ API Changes
New Features
---------------------
* LUCENE-10539: Return a stream of completions from FSTCompletion. (Dawid Weiss)
* LUCENE-10385: Implement Weight#count on IndexSortSortedNumericDocValuesRangeQuery
to speed up computing the number of hits when possible. (Lu Xugang, Luca Cavanna, Adrien Grand)

View File

@ -17,8 +17,19 @@
package org.apache.lucene.search.suggest.fst;
import java.io.IOException;
import java.util.*;
import org.apache.lucene.util.*;
import java.io.UncheckedIOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Spliterator;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.fst.FST;
import org.apache.lucene.util.fst.FST.Arc;
@ -53,7 +64,7 @@ public class FSTCompletion {
return utf8.utf8ToString() + "/" + bucket;
}
/** @see BytesRef#compareTo(BytesRef) */
/** Completions are equal when their {@link #utf8} images are equal (bucket is not compared). */
@Override
public int compareTo(Completion o) {
return this.utf8.compareTo(o.utf8);
@ -184,110 +195,174 @@ public class FSTCompletion {
return EMPTY_RESULT;
}
if (!higherWeightsFirst && rootArcs.length > 1) {
// We could emit a warning here (?). An optimal strategy for
// alphabetically sorted
// suggestions would be to add them with a constant weight -- this saves
// unnecessary
// traversals and sorting.
return lookup(key).sorted().limit(num).collect(Collectors.toList());
} else {
return lookup(key).limit(num).collect(Collectors.toList());
}
}
/**
* Lookup suggestions to <code>key</code> and return a stream of matching completions. The stream
* fetches completions dynamically - it can be filtered and limited to acquire the desired number
* of completions without collecting all of them.
*
* @param key The prefix to which suggestions should be sought.
* @return Returns the suggestions
*/
public Stream<Completion> lookup(CharSequence key) {
if (key.length() == 0 || automaton == null) {
return Stream.empty();
}
try {
BytesRef keyUtf8 = new BytesRef(key);
if (!higherWeightsFirst && rootArcs.length > 1) {
// We could emit a warning here (?). An optimal strategy for
// alphabetically sorted
// suggestions would be to add them with a constant weight -- this saves
// unnecessary
// traversals and sorting.
return lookupSortedAlphabetically(keyUtf8, num);
} else {
return lookupSortedByWeight(keyUtf8, num, false);
}
return lookupSortedByWeight(new BytesRef(key));
} catch (IOException e) {
// Should never happen, but anyway.
throw new RuntimeException(e);
}
}
/**
* Lookup suggestions sorted alphabetically <b>if weights are not constant</b>. This is a
* workaround: in general, use constant weights for alphabetically sorted result.
*/
private List<Completion> lookupSortedAlphabetically(BytesRef key, int num) throws IOException {
// Greedily get num results from each weight branch.
List<Completion> res = lookupSortedByWeight(key, num, true);
/** Lookup suggestions sorted by weight (descending order). */
private Stream<Completion> lookupSortedByWeight(BytesRef key) throws IOException {
// Sort and trim.
Collections.sort(res);
if (res.size() > num) {
res = res.subList(0, num);
}
return res;
}
/**
* Lookup suggestions sorted by weight (descending order).
*
* @param collectAll If <code>true</code>, the routine terminates immediately when <code>num
* </code> suggestions have been collected. If <code>false</code>, it will collect suggestions
* from all weight arcs (needed for {@link #lookupSortedAlphabetically}.
*/
private ArrayList<Completion> lookupSortedByWeight(BytesRef key, int num, boolean collectAll)
throws IOException {
// Don't overallocate the results buffers. This also serves the purpose of
// allowing the user of this class to request all matches using Integer.MAX_VALUE as
// the number of results.
final ArrayList<Completion> res = new ArrayList<>(Math.min(10, num));
final BytesRef output = BytesRef.deepCopyOf(key);
for (int i = 0; i < rootArcs.length; i++) {
final FST.Arc<Object> rootArc = rootArcs[i];
final FST.Arc<Object> arc = new FST.Arc<>().copyFrom(rootArc);
// Descend into the automaton using the key as prefix.
if (descendWithPrefix(arc, key)) {
// A subgraph starting from the current node has the completions
// of the key prefix. The arc we're at is the last key's byte,
// so we will collect it too.
output.length = key.length - 1;
if (collect(res, num, rootArc.label(), output, arc) && !collectAll) {
// We have enough suggestions to return immediately. Keep on looking
// for an
// exact match, if requested.
if (exactFirst) {
if (!checkExistingAndReorder(res, key)) {
int exactMatchBucket = getExactMatchStartingFromRootArc(i, key);
if (exactMatchBucket != -1) {
// Insert as the first result and truncate at num.
while (res.size() >= num) {
res.remove(res.size() - 1);
}
res.add(0, new Completion(key, exactMatchBucket));
}
}
}
// Look for an exact match first.
Completion exactCompletion;
if (exactFirst) {
Completion c = null;
for (int i = 0; i < rootArcs.length; i++) {
int exactMatchBucket = getExactMatchStartingFromRootArc(i, key);
if (exactMatchBucket != -1) {
// root arcs are sorted by decreasing weight so any first exact match will always win.
c = new Completion(key, exactMatchBucket);
break;
}
}
exactCompletion = c;
} else {
exactCompletion = null;
}
return res;
Stream<Completion> stream =
IntStream.range(0, rootArcs.length)
.boxed()
.flatMap(
i -> {
try {
final FST.Arc<Object> rootArc = rootArcs[i];
final FST.Arc<Object> arc = new FST.Arc<>().copyFrom(rootArc);
if (descendWithPrefix(arc, key)) {
// A subgraph starting from the current node has the completions
// of the key prefix. The arc we're at is the last key's byte,
// so we will collect it too.
final BytesRef output = BytesRef.deepCopyOf(key);
output.length = key.length;
return completionStream(output, rootArc.label(), arc);
} else {
return Stream.empty();
}
} catch (IOException e) {
throw new UncheckedIOException(e);
}
});
// if requested, return the exact completion first and omit it in any further completions.
if (exactFirst && exactCompletion != null) {
stream =
Stream.concat(
Stream.of(exactCompletion),
stream.filter(completion -> exactCompletion.compareTo(completion) != 0));
}
return stream;
}
/**
* Checks if the list of
* {@link org.apache.lucene.search.suggest.Lookup.LookupResult}s already has a
* <code>key</code>. If so, reorders that
* {@link org.apache.lucene.search.suggest.Lookup.LookupResult} to the first
* position.
*
* @return Returns <code>true<code> if and only if <code>list</code> contained
* <code>key</code>.
*/
private boolean checkExistingAndReorder(ArrayList<Completion> list, BytesRef key) {
// We assume list does not have duplicates (because of how the FST is created).
for (int i = list.size(); --i >= 0; ) {
if (key.equals(list.get(i).utf8)) {
// Key found. Unless already at i==0, remove it and push up front so
// that the ordering
// remains identical with the exception of the exact match.
list.add(0, list.remove(i));
return true;
/** Return a stream of all completions starting from the provided arc. */
private Stream<? extends Completion> completionStream(
BytesRef output, int bucket, Arc<Object> fromArc) throws IOException {
FST.BytesReader fstReader = automaton.getBytesReader();
class State {
Arc<Object> arc;
int outputLength;
State(Arc<Object> arc, int outputLength) throws IOException {
this.arc = automaton.readFirstTargetArc(arc, new Arc<>(), fstReader);
this.outputLength = outputLength;
}
}
return false;
ArrayDeque<State> states = new ArrayDeque<>();
states.addLast(new State(fromArc, output.length));
return StreamSupport.stream(
new Spliterator<>() {
@Override
public boolean tryAdvance(Consumer<? super Completion> action) {
try {
while (!states.isEmpty()) {
var state = states.peekLast();
output.length = state.outputLength;
var arc = state.arc;
var arcLabel = arc.label();
if (arcLabel == FST.END_LABEL) {
Completion completion = new Completion(output, bucket);
action.accept(completion);
if (arc.isLast()) {
states.removeLast();
} else {
automaton.readNextArc(arc, fstReader);
}
return true;
} else {
assert output.offset == 0;
if (output.length == output.bytes.length) {
output.bytes = ArrayUtil.grow(output.bytes);
}
output.bytes[output.length++] = (byte) arcLabel;
State newState = new State(arc, output.length);
if (arc.isLast()) {
states.removeLast();
} else {
automaton.readNextArc(arc, fstReader);
}
states.addLast(newState);
}
}
return false;
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}
@Override
public Spliterator<Completion> trySplit() {
// Don't try to split.
return null;
}
@Override
public long estimateSize() {
return Long.MAX_VALUE;
}
@Override
public int characteristics() {
return Spliterator.NONNULL | Spliterator.ORDERED;
}
},
false);
}
/**
@ -312,41 +387,6 @@ public class FSTCompletion {
return true;
}
/**
* Recursive collect lookup results from the automaton subgraph starting at <code>arc</code>.
*
* @param num Maximum number of results needed (early termination).
*/
private boolean collect(
List<Completion> res, int num, int bucket, BytesRef output, Arc<Object> arc)
throws IOException {
if (output.length == output.bytes.length) {
output.bytes = ArrayUtil.grow(output.bytes);
}
assert output.offset == 0;
output.bytes[output.length++] = (byte) arc.label();
FST.BytesReader fstReader = automaton.getBytesReader();
automaton.readFirstTargetArc(arc, arc, fstReader);
while (true) {
if (arc.label() == FST.END_LABEL) {
res.add(new Completion(output, bucket));
if (res.size() >= num) return true;
} else {
int save = output.length;
if (collect(res, num, bucket, output, new Arc<>().copyFrom(arc))) {
return true;
}
output.length = save;
}
if (arc.isLast()) {
break;
}
automaton.readNextArc(arc, fstReader);
}
return false;
}
/** Returns the bucket count (discretization thresholds). */
public int getBucketCount() {
return rootArcs.length;

View File

@ -17,14 +17,23 @@
package org.apache.lucene.search.suggest.fst;
import java.nio.charset.StandardCharsets;
import java.util.*;
import org.apache.lucene.search.suggest.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Locale;
import java.util.Random;
import java.util.stream.Collectors;
import org.apache.lucene.search.suggest.Input;
import org.apache.lucene.search.suggest.InputArrayIterator;
import org.apache.lucene.search.suggest.Lookup.LookupResult;
import org.apache.lucene.search.suggest.SuggestRebuildTestUtil;
import org.apache.lucene.search.suggest.TestLookupBenchmark;
import org.apache.lucene.search.suggest.fst.FSTCompletion.Completion;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.*;
import org.apache.lucene.util.BytesRef;
/** Unit tests for {@link FSTCompletion}. */
public class TestFSTCompletion extends LuceneTestCase {
@ -81,6 +90,20 @@ public class TestFSTCompletion extends LuceneTestCase {
assertMatchEquals(completion.lookup(stringToCharSequence("one"), 2), "one/0.0", "oneness/1.0");
}
public void testCompletionStream() throws Exception {
var completions =
completion
.lookup("fo")
.filter(completion -> !completion.utf8.utf8ToString().contains("fourteen"))
.sorted(
Comparator.comparing(
completion -> completion.utf8.utf8ToString().toLowerCase(Locale.ROOT)))
.collect(Collectors.toList());
assertMatchEquals(
completions, "foundation/1", "four/0", "fourblah/1", "fourier/0", "fourty/1.0");
}
public void testExactMatchReordering() throws Exception {
// Check reordering of exact matches.
assertMatchEquals(
@ -130,8 +153,17 @@ public class TestFSTCompletion extends LuceneTestCase {
}
public void testFullMatchList() throws Exception {
// one/0.0 is returned first because it's an exact match.
assertMatchEquals(
completion.lookup(stringToCharSequence("one"), Integer.MAX_VALUE),
"one/0.0",
"oneness/1.0",
"onerous/1.0",
"onesimus/1.0");
// full sorted order by weight+alphabetical.
assertMatchEquals(
completion.lookup(stringToCharSequence("on"), Integer.MAX_VALUE),
"oneness/1.0",
"onerous/1.0",
"onesimus/1.0",