mirror of https://github.com/apache/lucene.git
LUCENE-10539: Return a stream of completions from FSTCompletion. (#844)
This commit is contained in:
parent 75aadb9589
commit 05de9085ce
@@ -76,6 +76,8 @@ API Changes

New Features
---------------------

* LUCENE-10539: Return a stream of completions from FSTCompletion. (Dawid Weiss)

* LUCENE-10385: Implement Weight#count on IndexSortSortedNumericDocValuesRangeQuery
  to speed up computing the number of hits when possible. (Lu Xugang, Luca Cavanna, Adrien Grand)
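As a usage sketch (not taken from the commit itself), the stream-returning lookup added by LUCENE-10539 can be filtered and truncated lazily; the snippet below assumes an FSTCompletion instance built elsewhere, for example with FSTCompletionBuilder, and uses only method names that appear in the changes further below.

import java.util.List;
import java.util.stream.Collectors;
import org.apache.lucene.search.suggest.fst.FSTCompletion;

class StreamLookupSketch {
  // Sketch only: "completion" is assumed to be a prebuilt FSTCompletion (e.g. from
  // FSTCompletionBuilder). lookup(CharSequence) returns a lazy Stream<Completion>,
  // so limit(2) stops the FST traversal once two completions have been produced.
  static List<FSTCompletion.Completion> topTwo(FSTCompletion completion) {
    return completion.lookup("one").limit(2).collect(Collectors.toList());
  }

  // The pre-existing list-based variant is still available and, after this commit,
  // delegates to the stream internally.
  static List<FSTCompletion.Completion> topTwoAsList(FSTCompletion completion) {
    return completion.lookup("one", 2);
  }
}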
@@ -17,8 +17,19 @@
package org.apache.lucene.search.suggest.fst;

import java.io.IOException;
import java.util.*;
import org.apache.lucene.util.*;
import java.io.UncheckedIOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Spliterator;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.fst.FST;
import org.apache.lucene.util.fst.FST.Arc;
@@ -53,7 +64,7 @@ public class FSTCompletion {
      return utf8.utf8ToString() + "/" + bucket;
    }

    /** @see BytesRef#compareTo(BytesRef) */
    /** Completions are equal when their {@link #utf8} images are equal (bucket is not compared). */
    @Override
    public int compareTo(Completion o) {
      return this.utf8.compareTo(o.utf8);
@@ -184,110 +195,174 @@ public class FSTCompletion {
      return EMPTY_RESULT;
    }

    if (!higherWeightsFirst && rootArcs.length > 1) {
      // We could emit a warning here (?). An optimal strategy for
      // alphabetically sorted
      // suggestions would be to add them with a constant weight -- this saves
      // unnecessary
      // traversals and sorting.
      return lookup(key).sorted().limit(num).collect(Collectors.toList());
    } else {
      return lookup(key).limit(num).collect(Collectors.toList());
    }
  }

  /**
   * Lookup suggestions to <code>key</code> and return a stream of matching completions. The stream
   * fetches completions dynamically - it can be filtered and limited to acquire the desired number
   * of completions without collecting all of them.
   *
   * @param key The prefix to which suggestions should be sought.
   * @return Returns the suggestions
   */
  public Stream<Completion> lookup(CharSequence key) {
    if (key.length() == 0 || automaton == null) {
      return Stream.empty();
    }

    try {
      BytesRef keyUtf8 = new BytesRef(key);
      if (!higherWeightsFirst && rootArcs.length > 1) {
        // We could emit a warning here (?). An optimal strategy for
        // alphabetically sorted
        // suggestions would be to add them with a constant weight -- this saves
        // unnecessary
        // traversals and sorting.
        return lookupSortedAlphabetically(keyUtf8, num);
      } else {
        return lookupSortedByWeight(keyUtf8, num, false);
      }
      return lookupSortedByWeight(new BytesRef(key));
    } catch (IOException e) {
      // Should never happen, but anyway.
      throw new RuntimeException(e);
    }
  }

  /**
   * Lookup suggestions sorted alphabetically <b>if weights are not constant</b>. This is a
   * workaround: in general, use constant weights for alphabetically sorted result.
   */
  private List<Completion> lookupSortedAlphabetically(BytesRef key, int num) throws IOException {
    // Greedily get num results from each weight branch.
    List<Completion> res = lookupSortedByWeight(key, num, true);
  /** Lookup suggestions sorted by weight (descending order). */
  private Stream<Completion> lookupSortedByWeight(BytesRef key) throws IOException {

    // Sort and trim.
    Collections.sort(res);
    if (res.size() > num) {
      res = res.subList(0, num);
    }
    return res;
  }

  /**
   * Lookup suggestions sorted by weight (descending order).
   *
   * @param collectAll If <code>true</code>, the routine terminates immediately when <code>num
   *     </code> suggestions have been collected. If <code>false</code>, it will collect suggestions
   *     from all weight arcs (needed for {@link #lookupSortedAlphabetically}.
   */
  private ArrayList<Completion> lookupSortedByWeight(BytesRef key, int num, boolean collectAll)
      throws IOException {
    // Don't overallocate the results buffers. This also serves the purpose of
    // allowing the user of this class to request all matches using Integer.MAX_VALUE as
    // the number of results.
    final ArrayList<Completion> res = new ArrayList<>(Math.min(10, num));

    final BytesRef output = BytesRef.deepCopyOf(key);
    for (int i = 0; i < rootArcs.length; i++) {
      final FST.Arc<Object> rootArc = rootArcs[i];
      final FST.Arc<Object> arc = new FST.Arc<>().copyFrom(rootArc);

      // Descend into the automaton using the key as prefix.
      if (descendWithPrefix(arc, key)) {
        // A subgraph starting from the current node has the completions
        // of the key prefix. The arc we're at is the last key's byte,
        // so we will collect it too.
        output.length = key.length - 1;
        if (collect(res, num, rootArc.label(), output, arc) && !collectAll) {
          // We have enough suggestions to return immediately. Keep on looking
          // for an
          // exact match, if requested.
          if (exactFirst) {
            if (!checkExistingAndReorder(res, key)) {
              int exactMatchBucket = getExactMatchStartingFromRootArc(i, key);
              if (exactMatchBucket != -1) {
                // Insert as the first result and truncate at num.
                while (res.size() >= num) {
                  res.remove(res.size() - 1);
                }
                res.add(0, new Completion(key, exactMatchBucket));
              }
            }
          }
    // Look for an exact match first.
    Completion exactCompletion;
    if (exactFirst) {
      Completion c = null;
      for (int i = 0; i < rootArcs.length; i++) {
        int exactMatchBucket = getExactMatchStartingFromRootArc(i, key);
        if (exactMatchBucket != -1) {
          // root arcs are sorted by decreasing weight so any first exact match will always win.
          c = new Completion(key, exactMatchBucket);
          break;
        }
      }
      exactCompletion = c;
    } else {
      exactCompletion = null;
    }
    return res;

    Stream<Completion> stream =
        IntStream.range(0, rootArcs.length)
            .boxed()
            .flatMap(
                i -> {
                  try {
                    final FST.Arc<Object> rootArc = rootArcs[i];
                    final FST.Arc<Object> arc = new FST.Arc<>().copyFrom(rootArc);
                    if (descendWithPrefix(arc, key)) {
                      // A subgraph starting from the current node has the completions
                      // of the key prefix. The arc we're at is the last key's byte,
                      // so we will collect it too.
                      final BytesRef output = BytesRef.deepCopyOf(key);
                      output.length = key.length;
                      return completionStream(output, rootArc.label(), arc);
                    } else {
                      return Stream.empty();
                    }
                  } catch (IOException e) {
                    throw new UncheckedIOException(e);
                  }
                });

    // if requested, return the exact completion first and omit it in any further completions.
    if (exactFirst && exactCompletion != null) {
      stream =
          Stream.concat(
              Stream.of(exactCompletion),
              stream.filter(completion -> exactCompletion.compareTo(completion) != 0));
    }
    return stream;
  }

  /**
   * Checks if the list of
   * {@link org.apache.lucene.search.suggest.Lookup.LookupResult}s already has a
   * <code>key</code>. If so, reorders that
   * {@link org.apache.lucene.search.suggest.Lookup.LookupResult} to the first
   * position.
   *
   * @return Returns <code>true<code> if and only if <code>list</code> contained
   *     <code>key</code>.
   */
  private boolean checkExistingAndReorder(ArrayList<Completion> list, BytesRef key) {
    // We assume list does not have duplicates (because of how the FST is created).
    for (int i = list.size(); --i >= 0; ) {
      if (key.equals(list.get(i).utf8)) {
        // Key found. Unless already at i==0, remove it and push up front so
        // that the ordering
        // remains identical with the exception of the exact match.
        list.add(0, list.remove(i));
        return true;
  /** Return a stream of all completions starting from the provided arc. */
  private Stream<? extends Completion> completionStream(
      BytesRef output, int bucket, Arc<Object> fromArc) throws IOException {

    FST.BytesReader fstReader = automaton.getBytesReader();

    class State {
      Arc<Object> arc;
      int outputLength;

      State(Arc<Object> arc, int outputLength) throws IOException {
        this.arc = automaton.readFirstTargetArc(arc, new Arc<>(), fstReader);
        this.outputLength = outputLength;
      }
    }
    return false;

    ArrayDeque<State> states = new ArrayDeque<>();
    states.addLast(new State(fromArc, output.length));

    return StreamSupport.stream(
        new Spliterator<>() {
          @Override
          public boolean tryAdvance(Consumer<? super Completion> action) {
            try {
              while (!states.isEmpty()) {
                var state = states.peekLast();
                output.length = state.outputLength;
                var arc = state.arc;
                var arcLabel = arc.label();

                if (arcLabel == FST.END_LABEL) {
                  Completion completion = new Completion(output, bucket);
                  action.accept(completion);

                  if (arc.isLast()) {
                    states.removeLast();
                  } else {
                    automaton.readNextArc(arc, fstReader);
                  }

                  return true;
                } else {
                  assert output.offset == 0;
                  if (output.length == output.bytes.length) {
                    output.bytes = ArrayUtil.grow(output.bytes);
                  }
                  output.bytes[output.length++] = (byte) arcLabel;

                  State newState = new State(arc, output.length);

                  if (arc.isLast()) {
                    states.removeLast();
                  } else {
                    automaton.readNextArc(arc, fstReader);
                  }

                  states.addLast(newState);
                }
              }

              return false;
            } catch (IOException e) {
              throw new UncheckedIOException(e);
            }
          }

          @Override
          public Spliterator<Completion> trySplit() {
            // Don't try to split.
            return null;
          }

          @Override
          public long estimateSize() {
            return Long.MAX_VALUE;
          }

          @Override
          public int characteristics() {
            return Spliterator.NONNULL | Spliterator.ORDERED;
          }
        },
        false);
  }

  /**
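The completionStream method above builds its stream from a hand-rolled Spliterator, so FST arcs are only followed as the consumer pulls elements. Below is a generic, self-contained sketch of that StreamSupport pattern (illustrative names only, no Lucene types, not code from the commit):

import java.util.Spliterator;
import java.util.function.Consumer;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

class LazyStreamSketch {
  // Yields 0, 1, 2, ... lazily: each element is computed only when tryAdvance is
  // called, mirroring how completionStream() follows one FST arc per completion
  // that the consumer actually requests.
  static Stream<Integer> naturals() {
    Spliterator<Integer> sp =
        new Spliterator<Integer>() {
          private int next = 0;

          @Override
          public boolean tryAdvance(Consumer<? super Integer> action) {
            action.accept(next++);
            return true; // unbounded; callers are expected to limit() the stream
          }

          @Override
          public Spliterator<Integer> trySplit() {
            return null; // no parallel splitting, as in the suggester's spliterator
          }

          @Override
          public long estimateSize() {
            return Long.MAX_VALUE;
          }

          @Override
          public int characteristics() {
            return Spliterator.ORDERED | Spliterator.NONNULL;
          }
        };
    return StreamSupport.stream(sp, false);
  }

  public static void main(String[] args) {
    // Only five elements are ever produced, because the spliterator is lazy.
    naturals().limit(5).forEach(System.out::println);
  }
}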
@@ -312,41 +387,6 @@ public class FSTCompletion {
    return true;
  }

  /**
   * Recursive collect lookup results from the automaton subgraph starting at <code>arc</code>.
   *
   * @param num Maximum number of results needed (early termination).
   */
  private boolean collect(
      List<Completion> res, int num, int bucket, BytesRef output, Arc<Object> arc)
      throws IOException {
    if (output.length == output.bytes.length) {
      output.bytes = ArrayUtil.grow(output.bytes);
    }
    assert output.offset == 0;
    output.bytes[output.length++] = (byte) arc.label();
    FST.BytesReader fstReader = automaton.getBytesReader();
    automaton.readFirstTargetArc(arc, arc, fstReader);
    while (true) {
      if (arc.label() == FST.END_LABEL) {
        res.add(new Completion(output, bucket));
        if (res.size() >= num) return true;
      } else {
        int save = output.length;
        if (collect(res, num, bucket, output, new Arc<>().copyFrom(arc))) {
          return true;
        }
        output.length = save;
      }

      if (arc.isLast()) {
        break;
      }
      automaton.readNextArc(arc, fstReader);
    }
    return false;
  }

  /** Returns the bucket count (discretization thresholds). */
  public int getBucketCount() {
    return rootArcs.length;
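In the new lookupSortedByWeight(BytesRef) above, when exactFirst is set the exact completion is prepended and then filtered out of the remaining stream via Completion#compareTo, which compares only the utf8 image and ignores the bucket. The following plain-string sketch (made-up data, not code from the commit) shows the same concat-and-filter idiom:

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

class ExactFirstSketch {
  // Puts the exact match first and drops it from the rest of the (already
  // ordered) stream, the same Stream.concat + filter idiom used above.
  static List<String> exactFirst(String exact, Stream<String> completions) {
    return Stream.concat(Stream.of(exact), completions.filter(c -> !c.equals(exact)))
        .collect(Collectors.toList());
  }

  public static void main(String[] args) {
    // Prints [one, oneness, onerous].
    System.out.println(exactFirst("one", Stream.of("oneness", "one", "onerous")));
  }
}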
@@ -17,14 +17,23 @@
package org.apache.lucene.search.suggest.fst;

import java.nio.charset.StandardCharsets;
import java.util.*;
import org.apache.lucene.search.suggest.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Locale;
import java.util.Random;
import java.util.stream.Collectors;
import org.apache.lucene.search.suggest.Input;
import org.apache.lucene.search.suggest.InputArrayIterator;
import org.apache.lucene.search.suggest.Lookup.LookupResult;
import org.apache.lucene.search.suggest.SuggestRebuildTestUtil;
import org.apache.lucene.search.suggest.TestLookupBenchmark;
import org.apache.lucene.search.suggest.fst.FSTCompletion.Completion;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.*;
import org.apache.lucene.util.BytesRef;

/** Unit tests for {@link FSTCompletion}. */
public class TestFSTCompletion extends LuceneTestCase {
@@ -81,6 +90,20 @@ public class TestFSTCompletion extends LuceneTestCase {
    assertMatchEquals(completion.lookup(stringToCharSequence("one"), 2), "one/0.0", "oneness/1.0");
  }

  public void testCompletionStream() throws Exception {
    var completions =
        completion
            .lookup("fo")
            .filter(completion -> !completion.utf8.utf8ToString().contains("fourteen"))
            .sorted(
                Comparator.comparing(
                    completion -> completion.utf8.utf8ToString().toLowerCase(Locale.ROOT)))
            .collect(Collectors.toList());

    assertMatchEquals(
        completions, "foundation/1", "four/0", "fourblah/1", "fourier/0", "fourty/1.0");
  }

  public void testExactMatchReordering() throws Exception {
    // Check reordering of exact matches.
    assertMatchEquals(
@@ -130,8 +153,17 @@ public class TestFSTCompletion extends LuceneTestCase {
  }

  public void testFullMatchList() throws Exception {
    // one/0.0 is returned first because it's an exact match.
    assertMatchEquals(
        completion.lookup(stringToCharSequence("one"), Integer.MAX_VALUE),
        "one/0.0",
        "oneness/1.0",
        "onerous/1.0",
        "onesimus/1.0");

    // full sorted order by weight+alphabetical.
    assertMatchEquals(
        completion.lookup(stringToCharSequence("on"), Integer.MAX_VALUE),
        "oneness/1.0",
        "onerous/1.0",
        "onesimus/1.0",