LUCENE-9824: Hunspell suggestions: speed up ngram score calculation for each dictionary entry (#2457)

This commit is contained in:
Peter Gromov 2021-03-05 16:00:02 +01:00 committed by GitHub
parent 6e67b9f959
commit 99a4bbf3a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 212 additions and 32 deletions

View File

@ -1025,6 +1025,7 @@ public class Dictionary {
assert morphSep > 0;
assert morphSep > flagSep;
int sep = flagSep < 0 ? morphSep : flagSep;
if (sep == 0) return 0;
CharSequence toWrite;
String beforeSep = line.substring(0, sep);

View File

@ -23,7 +23,6 @@ import static org.apache.lucene.analysis.hunspell.Dictionary.AFFIX_STRIP_ORD;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.EnumSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Objects;
@ -67,7 +66,7 @@ class GeneratingSuggester {
PriorityQueue<Weighted<Root<String>>> roots = new PriorityQueue<>(natural.reversed());
List<Root<String>> entries = new ArrayList<>();
boolean ignoreTitleCaseRoots = originalCase == WordCase.LOWER && !dictionary.hasLanguage("de");
EnumSet<NGramOptions> options = EnumSet.of(NGramOptions.LONGER_WORSE);
TrigramAutomaton automaton = new TrigramAutomaton(word);
IntsRefFSTEnum<IntsRef> fstEnum = new IntsRefFSTEnum<>(dictionary.words);
InputOutput<IntsRef> mapping;
@ -75,6 +74,7 @@ class GeneratingSuggester {
speller.checkCanceled.run();
IntsRef key = mapping.input;
assert key.length > 0;
if (Math.abs(key.length - word.length()) > MAX_ROOT_LENGTH_DIFF) {
assert key.length < word.length(); // nextKey takes care of longer keys
continue;
@ -89,7 +89,8 @@ class GeneratingSuggester {
}
String lower = dictionary.toLowerCase(root);
int sc = ngram(3, word, lower, options) + commonPrefix(word, root);
int sc =
automaton.ngramScore(lower) - longerWorsePenalty(word, lower) + commonPrefix(word, root);
if (roots.size() == MAX_ROOTS && sc < roots.peek().score) {
continue;
@ -152,7 +153,7 @@ class GeneratingSuggester {
for (String guess : expandRoot(weighted.word, misspelled)) {
String lower = dictionary.toLowerCase(guess);
int sc =
ngram(misspelled.length(), misspelled, lower, EnumSet.of(NGramOptions.ANY_MISMATCH))
anyMismatchNgram(misspelled.length(), misspelled, lower, false)
+ commonPrefix(misspelled, guess);
if (sc > thresh) {
expanded.add(new Weighted<>(guess, sc));
@ -173,7 +174,7 @@ class GeneratingSuggester {
mw[k] = '*';
}
thresh += ngram(word.length(), word, new String(mw), EnumSet.of(NGramOptions.ANY_MISMATCH));
thresh += anyMismatchNgram(word.length(), word, new String(mw), false);
}
return thresh / 3 - 1;
}
@ -314,16 +315,14 @@ class GeneratingSuggester {
break;
}
int re =
ngram(2, word, lower, EnumSet.of(NGramOptions.ANY_MISMATCH, NGramOptions.WEIGHTED))
+ ngram(2, lower, word, EnumSet.of(NGramOptions.ANY_MISMATCH, NGramOptions.WEIGHTED));
int re = anyMismatchNgram(2, word, lower, true) + anyMismatchNgram(2, lower, word, true);
int score =
2 * lcs(word, lower)
- Math.abs(word.length() - lower.length())
+ commonCharacterPositionScore(word, lower)
+ commonPrefix(word, lower)
+ ngram(4, word, lower, EnumSet.of(NGramOptions.ANY_MISMATCH))
+ anyMismatchNgram(4, word, lower, false)
+ re
+ (re < (word.length() + lower.length()) * fact ? -1000 : 0);
bySimilarity.add(new Weighted<>(guess, score));
@ -374,14 +373,9 @@ class GeneratingSuggester {
}
// generate an n-gram score comparing s1 and s2
private static int ngram(int n, String s1, String s2, EnumSet<NGramOptions> opt) {
int score = 0;
static int ngramScore(int n, String s1, String s2, boolean weighted) {
int l1 = s1.length();
int l2 = s2.length();
if (l2 == 0) {
return 0;
}
int score = 0;
int[] lastStarts = new int[l1];
for (int j = 1; j <= n; j++) {
int ns = 0;
@ -394,7 +388,7 @@ class GeneratingSuggester {
continue;
}
}
if (opt.contains(NGramOptions.WEIGHTED)) {
if (weighted) {
ns--;
if (i == 0 || i == l1 - j) {
ns--; // side weight
@ -402,19 +396,21 @@ class GeneratingSuggester {
}
}
score = score + ns;
if (ns < 2 && !opt.contains(NGramOptions.WEIGHTED)) {
if (ns < 2 && !weighted) {
break;
}
}
return score;
}
int ns = 0;
if (opt.contains(NGramOptions.LONGER_WORSE)) {
ns = (l2 - l1) - 2;
// Counterpart of Hunspell's NGRAM_LONGER_WORSE flag: the amount by which s2 is more than
// two characters longer than s1, or zero when it is not.
private static int longerWorsePenalty(String s1, String s2) {
  int excess = s2.length() - s1.length() - 2;
  return excess > 0 ? excess : 0;
}
if (opt.contains(NGramOptions.ANY_MISMATCH)) {
ns = Math.abs(l2 - l1) - 2;
}
return score - Math.max(ns, 0);
// Counterpart of Hunspell's NGRAM_ANY_MISMATCH flag: the plain n-gram score reduced by the
// amount by which the two lengths differ by more than two, in either direction.
private static int anyMismatchNgram(int n, String s1, String s2, boolean weighted) {
  int rawScore = ngramScore(n, s1, s2, weighted);
  int lengthGap = Math.abs(s2.length() - s1.length());
  int penalty = Math.max(lengthGap - 2, 0);
  return rawScore - penalty;
}
private static int indexOfSubstring(
@ -471,12 +467,6 @@ class GeneratingSuggester {
return commonScore;
}
private enum NGramOptions {
WEIGHTED,
LONGER_WORSE,
ANY_MISMATCH
}
private static class Weighted<T extends Comparable<T>> implements Comparable<Weighted<T>> {
final T word;
final int score;

View File

@ -0,0 +1,122 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.analysis.hunspell;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.automaton.Automaton;
import org.apache.lucene.util.automaton.CharacterRunAutomaton;
import org.apache.lucene.util.automaton.Operations;
/**
 * A matcher precomputed from one string that yields the same value as the non-weighted {@link
 * GeneratingSuggester#ngramScore} with {@code n = 3}, but faster (in O(s2.length) time).
 *
 * <p>The bit set used to remember already-counted substrings is an instance field that is reset
 * at the start of every {@link #ngramScore(String)} call.
 */
class TrigramAutomaton {
  private static final int N = 3;

  private final CharacterRunAutomaton automaton;

  /** For each automaton state, how many times the substring leading to it occurs in s1. */
  private final int[] state2Score;

  /** States whose score has already been added during the current {@link #ngramScore} call. */
  private final FixedBitSet countedSubstrings;

  /**
   * Indexes every substring of {@code s1} of length 1 to {@link #N}.
   *
   * @param s1 the string whose substrings are later looked up in other strings
   */
  TrigramAutomaton(String s1) {
    int length = s1.length();
    Map<String, Integer> occurrences = new HashMap<>();
    Automaton.Builder builder = new Automaton.Builder(length * N, length * N);
    int root = builder.createState();

    // For each start offset, add a chain of up to N transitions from the root and count
    // every prefix of that chain as one occurrence of the corresponding substring.
    for (int from = 0; from < length; from++) {
      int to = Math.min(length, from + N);
      int current = root;
      for (int pos = from; pos < to; pos++) {
        occurrences.merge(s1.substring(from, pos + 1), 1, Integer::sum);
        int target = builder.createState();
        builder.addTransition(current, target, s1.charAt(pos));
        current = target;
      }
    }

    Automaton deterministic =
        Operations.determinize(builder.finish(), Operations.DEFAULT_MAX_DETERMINIZED_STATES);
    automaton = new CharacterRunAutomaton(deterministic);

    // Attach the occurrence count of each distinct substring to the state it ends in.
    state2Score = new int[automaton.getSize()];
    for (Map.Entry<String, Integer> occurrence : occurrences.entrySet()) {
      int reached = runAutomatonOnStringChars(occurrence.getKey());
      assert state2Score[reached] == 0;
      state2Score[reached] = occurrence.getValue();
    }
    countedSubstrings = new FixedBitSet(state2Score.length);
  }

  /** Feeds all chars of {@code s} to the automaton starting from state 0. */
  private int runAutomatonOnStringChars(String s) {
    int current = 0;
    int count = s.length();
    for (int pos = 0; pos < count; pos++) {
      current = automaton.step(current, s.charAt(pos));
    }
    return current;
  }

  /**
   * Scores {@code s2} against the string this instance was built from.
   *
   * @param s2 the string to scan
   * @return the sum of per-length scores, where longer lengths only contribute while the
   *     preceding shorter length scored at least 2
   */
  int ngramScore(String s2) {
    countedSubstrings.clear(0, countedSubstrings.length());

    int unigrams = 0;
    int bigrams = 0;
    int trigrams = 0;

    // States reached by the substrings ending just before the current char:
    // afterOne covers the last char only, afterTwo the last two chars.
    // Non-positive values are treated as "nothing matched".
    int afterOne = -1;
    int afterTwo = -1;

    for (int i = 0, len = s2.length(); i < len; i++) {
      char c = s2.charAt(i);

      int three = afterTwo > 0 ? automaton.step(afterTwo, c) : 0;
      if (three > 0) {
        trigrams += substringScore(three, countedSubstrings);
      }

      int two = afterOne > 0 ? automaton.step(afterOne, c) : 0;
      if (two > 0) {
        bigrams += substringScore(two, countedSubstrings);
      }

      int one = automaton.step(0, c);
      if (one > 0) {
        unigrams += substringScore(one, countedSubstrings);
      }

      afterTwo = two;
      afterOne = one;
    }

    if (unigrams < 2) {
      return unigrams;
    }
    if (bigrams < 2) {
      return unigrams + bigrams;
    }
    return unigrams + bigrams + trigrams;
  }

  /**
   * Returns the score stored for {@code state} the first time it is seen and 0 on any later
   * request for the same state, marking it in {@code seen}.
   */
  private int substringScore(int state, FixedBitSet seen) {
    if (seen.getAndSet(state)) {
      return 0;
    }
    int score = state2Score[state];
    assert score > 0;
    return score;
  }
}

View File

@ -0,0 +1,67 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.analysis.hunspell;
import java.util.stream.Collectors;
import org.apache.lucene.util.LuceneTestCase;
/**
 * Checks that {@link TrigramAutomaton#ngramScore} returns the same value as the non-weighted
 * {@code GeneratingSuggester.ngramScore} with {@code n = 3}.
 */
public class TestTrigramAutomaton extends LuceneTestCase {

  /** Fixed word pairs. */
  public void testSameScore() {
    checkScores("look", "looked");
    checkScores("look", "cool");
    checkScores("abracadabra", "abraham");
  }

  /** Random strings glued together from short pieces, including surrogate pairs. */
  public void testRandomized() {
    String[] alphabet = {
      "a",
      "b",
      "c",
      "aa",
      "ab",
      "abc",
      "ccc",
      "\uD800\uDFD1",
      "\uD800\uDFD2",
      "\uD800\uDFD2\uD800\uDFD1",
      "\uD800\uDFD2\uD800\uDFD2"
    };
    int remaining = 100;
    while (remaining-- > 0) {
      String first = randomConcatenation(alphabet);
      String second = randomConcatenation(alphabet);
      checkScores(first, second);
    }
  }

  /** Joins between 1 and 20 randomly chosen elements of {@code alphabet}. */
  private String randomConcatenation(String[] alphabet) {
    int pieces = random().nextInt(20) + 1;
    return random()
        .ints(0, alphabet.length)
        .limit(pieces)
        .mapToObj(index -> alphabet[index])
        .collect(Collectors.joining());
  }

  /** Compares both implementations in both argument orders. */
  private void checkScores(String s1, String s2) {
    String message = "Fails: checkScores(\"" + s1 + "\", \"" + s2 + "\")";
    assertSameScore(message, s1, s2);
    assertSameScore(message, s2, s1);
  }

  /** Asserts that the automaton built from {@code base} scores {@code other} like the reference. */
  private void assertSameScore(String message, String base, String other) {
    int expected = GeneratingSuggester.ngramScore(3, base, other, false);
    int actual = new TrigramAutomaton(base).ngramScore(other);
    assertEquals(message, expected, actual);
  }
}