LUCENE-9825: Hunspell: reverse the "words" trie for faster word lookup/suggestions

This commit is contained in:
Peter Gromov 2021-03-05 14:34:36 +01:00
parent 99a4bbf3a0
commit 4842e0c9ca
7 changed files with 396 additions and 168 deletions

View File

@ -16,7 +16,8 @@
*/
package org.apache.lucene.analysis.hunspell;
import static org.apache.lucene.analysis.hunspell.AffixKind.*;
import static org.apache.lucene.analysis.hunspell.AffixKind.PREFIX;
import static org.apache.lucene.analysis.hunspell.AffixKind.SUFFIX;
import java.io.BufferedInputStream;
import java.io.BufferedReader;
@ -53,8 +54,6 @@ import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.CharsRef;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.IntsRef;
import org.apache.lucene.util.IntsRefBuilder;
@ -91,14 +90,8 @@ public class Dictionary {
*/
ArrayList<AffixCondition> patterns = new ArrayList<>();
/**
* The entries in the .dic file, mapping to their set of flags. the fst output is the ordinal list
* for flagLookup.
*/
FST<IntsRef> words;
/** A Bloom filter over {@link #words} to avoid unnecessary expensive FST traversals */
FixedBitSet wordHashes;
/** The entries in the .dic file, mapping to their set of flags */
WordStorage words;
/**
* The list of unique flagsets (wordforms). theoretically huge, but practically small (for Polish
@ -257,9 +250,8 @@ public class Dictionary {
// read dictionary entries
IndexOutput unsorted = tempDir.createTempOutput(tempFileNamePrefix, "dat", IOContext.DEFAULT);
int wordCount = mergeDictionaries(dictionaries, decoder, unsorted);
wordHashes = new FixedBitSet(Integer.highestOneBit(wordCount * 10));
String sortedFile = sortWordsOffline(tempDir, tempFileNamePrefix, unsorted);
words = readSortedDictionaries(tempDir, sortedFile, flagEnumerator);
words = readSortedDictionaries(tempDir, sortedFile, flagEnumerator, wordCount);
flagLookup = flagEnumerator.finish();
aliases = null; // no longer needed
morphAliases = null; // no longer needed
@ -272,36 +264,27 @@ public class Dictionary {
/** Looks up Hunspell word forms from the dictionary */
IntsRef lookupWord(char[] word, int offset, int length) {
int hash = CharsRef.stringHashCode(word, offset, length);
if (!wordHashes.get(Math.abs(hash) % wordHashes.length())) {
return null;
}
return lookup(words, word, offset, length);
return words.lookupWord(word, offset, length);
}
// only for testing
IntsRef lookupPrefix(char[] word) {
return lookup(prefixes, word, 0, word.length);
return lookup(prefixes, word);
}
// only for testing
IntsRef lookupSuffix(char[] word) {
return lookup(suffixes, word, 0, word.length);
return lookup(suffixes, word);
}
IntsRef lookup(FST<IntsRef> fst, char[] word, int offset, int length) {
if (fst == null) {
return null;
}
private IntsRef lookup(FST<IntsRef> fst, char[] word) {
final FST.BytesReader bytesReader = fst.getBytesReader();
final FST.Arc<IntsRef> arc = fst.getFirstArc(new FST.Arc<>());
// Accumulate output as we go
IntsRef output = fst.outputs.getNoOutput();
int l = offset + length;
for (int i = offset, cp; i < l; i += Character.charCount(cp)) {
cp = Character.codePointAt(word, i, l);
for (int i = 0, cp; i < word.length; i += Character.charCount(cp)) {
cp = Character.codePointAt(word, i, word.length);
output = nextArc(fst, arc, bytesReader, output, cp);
if (output == null) {
return null;
@ -1134,13 +1117,13 @@ public class Dictionary {
return sorted;
}
private FST<IntsRef> readSortedDictionaries(
Directory tempDir, String sorted, FlagEnumerator flags) throws IOException {
private WordStorage readSortedDictionaries(
Directory tempDir, String sorted, FlagEnumerator flags, int wordCount) throws IOException {
boolean success = false;
Map<String, Integer> morphIndices = new HashMap<>();
EntryGrouper grouper = new EntryGrouper(flags);
WordStorage.Builder builder = new WordStorage.Builder(wordCount, hasCustomMorphData, flags);
try (ByteSequencesReader reader =
new ByteSequencesReader(tempDir.openChecksumInput(sorted, IOContext.READONCE), sorted)) {
@ -1180,6 +1163,8 @@ public class Dictionary {
entry = line.substring(0, flagSep);
}
if (entry.isEmpty()) continue;
int morphDataID = 0;
if (end + 1 < line.length()) {
List<String> morphFields = readMorphFields(entry, line.substring(end + 1));
@ -1189,14 +1174,12 @@ public class Dictionary {
}
}
wordHashes.set(Math.abs(entry.hashCode()) % wordHashes.length());
grouper.add(entry, wordForm, morphDataID);
builder.add(entry, wordForm, morphDataID);
}
// finalize last entry
grouper.flushGroup();
success = true;
return grouper.words.compile();
return builder.build();
} finally {
if (success) {
tempDir.deleteFile(sorted);
@ -1275,76 +1258,6 @@ public class Dictionary {
return word[0] == 'İ' && !alternateCasing;
}
private class EntryGrouper {
final FSTCompiler<IntsRef> words =
new FSTCompiler<>(FST.INPUT_TYPE.BYTE4, IntSequenceOutputs.getSingleton());
private final List<char[]> group = new ArrayList<>();
private final List<Integer> morphDataIDs = new ArrayList<>();
private final IntsRefBuilder scratchInts = new IntsRefBuilder();
private String currentEntry = null;
private final FlagEnumerator flagEnumerator;
EntryGrouper(FlagEnumerator flagEnumerator) {
this.flagEnumerator = flagEnumerator;
}
void add(String entry, char[] flags, int morphDataID) throws IOException {
if (!entry.equals(currentEntry)) {
if (currentEntry != null) {
if (entry.compareTo(currentEntry) < 0) {
throw new IllegalArgumentException("out of order: " + entry + " < " + currentEntry);
}
flushGroup();
}
currentEntry = entry;
}
group.add(flags);
if (hasCustomMorphData) {
morphDataIDs.add(morphDataID);
}
}
void flushGroup() throws IOException {
IntsRefBuilder currentOrds = new IntsRefBuilder();
boolean hasNonHidden = false;
for (char[] flags : group) {
if (!hasHiddenFlag(flags)) {
hasNonHidden = true;
break;
}
}
for (int i = 0; i < group.size(); i++) {
char[] flags = group.get(i);
if (hasNonHidden && hasHiddenFlag(flags)) {
continue;
}
currentOrds.append(flagEnumerator.add(flags));
if (hasCustomMorphData) {
currentOrds.append(morphDataIDs.get(i));
}
}
Util.toUTF32(currentEntry, scratchInts);
words.add(scratchInts.get(), currentOrds.get());
group.clear();
morphDataIDs.clear();
}
}
private static boolean hasHiddenFlag(char[] flags) {
for (char flag : flags) {
if (flag == HIDDEN_FLAG) {
return true;
}
}
return false;
}
private void parseAlias(String line) {
String[] ruleArgs = line.split("\\s+");
if (aliases == null) {

View File

@ -20,7 +20,6 @@ import static org.apache.lucene.analysis.hunspell.Dictionary.AFFIX_APPEND;
import static org.apache.lucene.analysis.hunspell.Dictionary.AFFIX_FLAG;
import static org.apache.lucene.analysis.hunspell.Dictionary.AFFIX_STRIP_ORD;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.LinkedHashSet;
@ -30,11 +29,8 @@ import java.util.PriorityQueue;
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collectors;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.IntsRef;
import org.apache.lucene.util.fst.FST;
import org.apache.lucene.util.fst.IntsRefFSTEnum;
import org.apache.lucene.util.fst.IntsRefFSTEnum.InputOutput;
/**
* A class that traverses the entire dictionary and applies affix rules to check if those yield
@ -68,68 +64,44 @@ class GeneratingSuggester {
boolean ignoreTitleCaseRoots = originalCase == WordCase.LOWER && !dictionary.hasLanguage("de");
TrigramAutomaton automaton = new TrigramAutomaton(word);
IntsRefFSTEnum<IntsRef> fstEnum = new IntsRefFSTEnum<>(dictionary.words);
InputOutput<IntsRef> mapping;
while ((mapping = nextKey(fstEnum, word.length() + 4)) != null) {
speller.checkCanceled.run();
dictionary.words.processAllWords(
word.length() + 4,
(rootChars, forms) -> {
speller.checkCanceled.run();
IntsRef key = mapping.input;
assert key.length > 0;
if (Math.abs(key.length - word.length()) > MAX_ROOT_LENGTH_DIFF) {
assert key.length < word.length(); // nextKey takes care of longer keys
continue;
}
assert rootChars.length > 0;
if (Math.abs(rootChars.length - word.length()) > MAX_ROOT_LENGTH_DIFF) {
assert rootChars.length < word.length(); // nextKey takes care of longer keys
return;
}
String root = toString(key);
filterSuitableEntries(root, mapping.output, entries);
if (entries.isEmpty()) continue;
String root = rootChars.toString();
filterSuitableEntries(root, forms, entries);
if (entries.isEmpty()) return;
if (ignoreTitleCaseRoots && WordCase.caseOf(root) == WordCase.TITLE) {
continue;
}
if (ignoreTitleCaseRoots && WordCase.caseOf(rootChars) == WordCase.TITLE) {
return;
}
String lower = dictionary.toLowerCase(root);
int sc =
automaton.ngramScore(lower) - longerWorsePenalty(word, lower) + commonPrefix(word, root);
String lower = dictionary.toLowerCase(root);
int sc =
automaton.ngramScore(lower)
- longerWorsePenalty(word, lower)
+ commonPrefix(word, root);
if (roots.size() == MAX_ROOTS && sc < roots.peek().score) {
continue;
}
if (roots.size() == MAX_ROOTS && sc < roots.peek().score) {
return;
}
entries.forEach(e -> roots.add(new Weighted<>(e, sc)));
while (roots.size() > MAX_ROOTS) {
roots.poll();
}
});
entries.forEach(e -> roots.add(new Weighted<>(e, sc)));
while (roots.size() > MAX_ROOTS) {
roots.poll();
}
}
return roots.stream().sorted().collect(Collectors.toList());
}
private static InputOutput<IntsRef> nextKey(IntsRefFSTEnum<IntsRef> fstEnum, int maxLen) {
try {
InputOutput<IntsRef> next = fstEnum.next();
while (next != null && next.input.length > maxLen) {
int offset = next.input.offset;
int[] ints = ArrayUtil.copyOfSubArray(next.input.ints, offset, offset + maxLen);
if (ints[ints.length - 1] == Integer.MAX_VALUE) {
throw new AssertionError("Too large char");
}
ints[ints.length - 1]++;
next = fstEnum.seekCeil(new IntsRef(ints, 0, ints.length));
}
return next;
} catch (IOException e) {
throw new RuntimeException(e);
}
}
private static String toString(IntsRef key) {
char[] chars = new char[key.length];
for (int i = 0; i < key.length; i++) {
chars[i] = (char) key.ints[i + key.offset];
}
return new String(chars);
}
private void filterSuitableEntries(String word, IntsRef forms, List<Root<String>> result) {
result.clear();
for (int i = 0; i < forms.length; i += dictionary.formStep()) {
@ -363,7 +335,7 @@ class GeneratingSuggester {
return result;
}
private static int commonPrefix(String s1, String s2) {
static int commonPrefix(String s1, String s2) {
int i = 0;
int limit = Math.min(s1.length(), s2.length());
while (i < limit && s1.charAt(i) == s2.charAt(i)) {

View File

@ -234,6 +234,8 @@ class ModifyingSuggester {
}
private void tryRemovingChar(String word) {
if (word.length() == 1) return;
for (int i = 0; i < word.length(); i++) {
trySuggestion(word.substring(0, i) + word.substring(i + 1));
}

View File

@ -94,6 +94,10 @@ final class Stemmer {
}
List<CharsRef> list = new ArrayList<>();
if (length == 0) {
return list;
}
RootProcessor processor =
(stem, formID, stemException) -> {
list.add(newStem(stem, stemException));

View File

@ -0,0 +1,338 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.analysis.hunspell;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiConsumer;
import org.apache.lucene.store.ByteArrayDataInput;
import org.apache.lucene.store.ByteArrayDataOutput;
import org.apache.lucene.store.DataOutput;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.CharsRef;
import org.apache.lucene.util.IntsRef;
import org.apache.lucene.util.IntsRefBuilder;
import org.apache.lucene.util.fst.IntSequenceOutputs;
/**
* A data structure for memory-efficient word storage and fast lookup/enumeration. Each dictionary
* entry is stored as:
*
* <ol>
* <li>the last character
* <li>pointer to a similar entry for the prefix (all characters except the last one)
* <li>value data: a list of ints representing word flags and morphological data, and a pointer to
* hash collisions, if any
* </ol>
*
* There's only one entry for each prefix, so it's like a trie/{@link
org.apache.lucene.util.fst.FST}, but a reversed one: each node points to a single previous node
instead of several following ones. For example, "abc" and "abd" point to the same prefix entry
* "ab" which points to "a" which points to 0.<br>
* <br>
* The entries are stored in a contiguous byte array, identified by their offsets, using {@link
DataOutput#writeVInt} (VINT) format for compression.
*/
class WordStorage {
  /**
   * A map from word's hash (modulo array's length) into the offset of the last entry in {@link
   * #wordData} with this hash. Negated, if there's more than one entry with the same hash.
   */
  private final int[] hashTable;

  /**
   * An array of word entries:
   *
   * <ul>
   *   <li>VINT: the word's last character
   *   <li>VINT: pointer to the entry for the same word without the last character. It's relative:
   *       the difference of this entry's start and the prefix's entry start. 0 for single-character
   *       entries
   *   <li>Optional, for whole-word entries only (entries that terminate a dictionary word; pure
   *       prefix nodes store just the two VINTs above):
   *       <ul>
   *         <li>VINT: the length of the word form data, returned from {@link #lookupWord}
   *         <li>n * VINT: the word form data
   *         <li>Optional, for hash-colliding entries only:
   *             <ul>
   *               <li>BYTE: 1 if the next collision entry has further collisions, 0 if it's the
   *                   last of the entries with the same hash
   *               <li>VINT: (relative) pointer to the previous entry with the same hash
   *             </ul>
   *       </ul>
   * </ul>
   */
  private final byte[] wordData;

  private WordStorage(int[] hashTable, byte[] wordData) {
    this.hashTable = hashTable;
    this.wordData = wordData;
  }

  /**
   * Looks up the given word and returns its form data (the ints written by the {@link Builder}:
   * flag ids and, when present, morph data ids), or {@code null} if the word is not in the
   * dictionary.
   */
  IntsRef lookupWord(char[] word, int offset, int length) {
    assert length > 0;
    int hash = Math.abs(CharsRef.stringHashCode(word, offset, length) % hashTable.length);
    int pos = hashTable[hash];
    if (pos == 0) {
      // empty bucket: no dictionary word hashes here
      return null;
    }

    // a negated offset marks a bucket shared by several entries, chained via collision pointers
    boolean collision = pos < 0;
    pos = Math.abs(pos);

    char lastChar = word[offset + length - 1];
    ByteArrayDataInput in = new ByteArrayDataInput(wordData);
    while (true) {
      in.setPosition(pos);
      char c = (char) in.readVInt();
      // prefix pointers are stored as backward deltas from this entry's start
      int prevPos = pos - in.readVInt();
      int beforeForms = in.getPosition();
      // compare the last character here, the rest by walking the prefix chain backwards
      boolean found = c == lastChar && isSameString(word, offset, length - 1, prevPos, in);
      if (!collision && !found) {
        return null;
      }

      in.setPosition(beforeForms);
      int formLength = in.readVInt();
      if (found) {
        IntsRef forms = new IntsRef(formLength);
        readForms(forms, in, formLength);
        return forms;
      } else {
        skipVInts(in, formLength);
      }

      // advance to the previous entry sharing this hash bucket
      collision = in.readByte() == 1;
      pos -= in.readVInt();
    }
  }

  /** Skips {@code count} VINT-encoded values: a byte with the high bit clear ends one VINT. */
  private static void skipVInts(ByteArrayDataInput in, int count) {
    for (int i = 0; i < count; ) {
      if (in.readByte() >= 0) i++;
    }
  }

  /**
   * Enumerates all stored words shorter than {@code maxLength} along with their form data.
   *
   * @param processor is invoked for each word. Note that the passed arguments (word and form) are
   *     reused, so they can be modified in any way, but may not be saved for later by the processor
   */
  void processAllWords(int maxLength, BiConsumer<CharsRef, IntsRef> processor) {
    CharsRef chars = new CharsRef(maxLength);
    IntsRef forms = new IntsRef();
    ByteArrayDataInput in = new ByteArrayDataInput(wordData);
    for (int pos : hashTable) {
      boolean collision = pos < 0;
      pos = Math.abs(pos);

      while (pos != 0) {
        // reconstruct the word right-to-left into the tail of the scratch buffer
        int wordStart = maxLength - 1;

        in.setPosition(pos);
        chars.chars[wordStart] = (char) in.readVInt();
        int prevPos = pos - in.readVInt();

        int dataLength = in.readVInt();
        if (forms.ints.length < dataLength) {
          forms.ints = new int[dataLength];
        }
        readForms(forms, in, dataLength);

        int afterForms = in.getPosition();

        while (prevPos != 0 && wordStart > 0) {
          in.setPosition(prevPos);
          chars.chars[--wordStart] = (char) in.readVInt();
          prevPos -= in.readVInt();
        }

        // wordStart == 0 means the word is of length >= maxLength (possibly truncated): skip it
        if (wordStart > 0) {
          chars.offset = wordStart;
          chars.length = maxLength - wordStart;
          processor.accept(chars, forms);
        }

        if (!collision) {
          break;
        }

        // reading a 0/1 byte via readVInt is equivalent to readByte for these values
        in.setPosition(afterForms);
        collision = in.readVInt() == 1;
        pos -= in.readVInt();
      }
    }
  }

  /**
   * Compares {@code word[offset..offset+length)} with the stored string whose last-character entry
   * is at {@code dataPos}, walking the prefix chain backwards one character at a time.
   */
  private boolean isSameString(
      char[] word, int offset, int length, int dataPos, ByteArrayDataInput in) {
    for (int i = length - 1; i >= 0; i--) {
      in.setPosition(dataPos);
      char c = (char) in.readVInt();
      if (c != word[i + offset]) {
        return false;
      }
      dataPos -= in.readVInt();
      if (dataPos == 0) {
        // reached the trie root: equal only if the query is fully consumed, too
        return i == 0;
      }
    }
    // FIX: also require dataPos == 0 here. Previously `return length == 0;` reported a match when
    // the query was a proper suffix of a stored word (dataPos != 0 means the stored chain still
    // has characters left), causing false-positive lookups on hash-bucket collisions.
    return length == 0 && dataPos == 0;
  }

  /** Reads {@code length} VINTs into {@code forms} (caller guarantees sufficient capacity). */
  private void readForms(IntsRef forms, ByteArrayDataInput in, int length) {
    for (int i = 0; i < length; i++) {
      forms.ints[i] = in.readVInt();
    }
    forms.length = length;
  }

  /**
   * Builds a {@link WordStorage} from dictionary entries fed in sorted order via {@link #add}.
   * Consecutive equal entries are grouped so each unique word gets a single entry holding all its
   * flag sets.
   */
  static class Builder {
    private final boolean hasCustomMorphData;
    private final int[] hashTable;
    private byte[] wordData;
    // per-bucket collision-chain lengths, used to detect pathological hashing
    private final int[] chainLengths;

    // flag sets (and optional morph data ids) accumulated for the current entry
    private final List<char[]> group = new ArrayList<>();
    private final List<Integer> morphDataIDs = new ArrayList<>();
    private String currentEntry = null;
    private final FlagEnumerator flagEnumerator;

    private final ByteArrayDataOutput dataWriter;
    // length of the common prefix between the previous and the current entry, and the offset of
    // that prefix's node, so shared prefixes are stored only once
    int commonPrefixLength, commonPrefixPos;

    Builder(int wordCount, boolean hasCustomMorphData, FlagEnumerator flagEnumerator) {
      this.flagEnumerator = flagEnumerator;
      this.hasCustomMorphData = hasCustomMorphData;

      hashTable = new int[wordCount];
      // initial size estimate; grown on demand in ensureArraySize
      wordData = new byte[wordCount * 6];

      dataWriter = new ByteArrayDataOutput(wordData);
      dataWriter.writeByte((byte) 0); // zero index is root, contains nothing
      chainLengths = new int[hashTable.length];
    }

    /**
     * Registers a dictionary entry. Entries must arrive in non-decreasing order; equal consecutive
     * entries are merged into one word with several forms.
     *
     * @throws IllegalArgumentException if entries come out of order
     */
    void add(String entry, char[] flags, int morphDataID) throws IOException {
      if (!entry.equals(currentEntry)) {
        if (currentEntry != null) {
          if (entry.compareTo(currentEntry) < 0) {
            throw new IllegalArgumentException("out of order: " + entry + " < " + currentEntry);
          }
          int pos = flushGroup();
          commonPrefixLength = GeneratingSuggester.commonPrefix(currentEntry, entry);
          // walk back from the previous entry's node to the node of the shared prefix
          ByteArrayDataInput in = new ByteArrayDataInput(wordData);
          in.setPosition(pos);
          for (int i = currentEntry.length() - 1; i >= commonPrefixLength; i--) {
            char c = (char) in.readVInt();
            assert c == currentEntry.charAt(i);
            pos -= in.readVInt();
            in.setPosition(pos);
          }
          commonPrefixPos = pos;
        }
        currentEntry = entry;
      }

      group.add(flags);
      if (hasCustomMorphData) {
        morphDataIDs.add(morphDataID);
      }
    }

    /**
     * Writes the accumulated forms of {@link #currentEntry} into {@link #wordData} and links the
     * entry into its hash bucket.
     *
     * @return the offset of the entry's last-character node
     */
    private int flushGroup() throws IOException {
      IntsRefBuilder currentOrds = new IntsRefBuilder();

      // if any form is non-hidden, hidden forms are dropped (they only exist to mark case variants)
      boolean hasNonHidden = false;
      for (char[] flags : group) {
        if (!hasHiddenFlag(flags)) {
          hasNonHidden = true;
          break;
        }
      }

      for (int i = 0; i < group.size(); i++) {
        char[] flags = group.get(i);
        if (hasNonHidden && hasHiddenFlag(flags)) {
          continue;
        }

        currentOrds.append(flagEnumerator.add(flags));
        if (hasCustomMorphData) {
          currentOrds.append(morphDataIDs.get(i));
        }
      }

      // write pure prefix nodes (char + backward delta only) for characters not shared with the
      // previous entry, chaining from the common-prefix node
      int lastPos = commonPrefixPos;
      for (int i = commonPrefixLength; i < currentEntry.length() - 1; i++) {
        int pos = dataWriter.getPosition();
        ensureArraySize(0, false);
        dataWriter.writeVInt(currentEntry.charAt(i));
        dataWriter.writeVInt(pos - lastPos);
        lastPos = pos;
      }

      int pos = dataWriter.getPosition();
      int hash = Math.abs(currentEntry.hashCode() % hashTable.length);
      int collision = hashTable[hash];
      // a negated head offset marks a bucket with more than one entry
      hashTable[hash] = collision == 0 ? pos : -pos;

      if (++chainLengths[hash] > 20) {
        throw new RuntimeException(
            "Too many collisions, please report this to dev@lucene.apache.org");
      }

      // whole-word node: char, prefix delta, forms, and an optional collision link
      ensureArraySize(currentOrds.length(), collision != 0);
      dataWriter.writeVInt(currentEntry.charAt(currentEntry.length() - 1));
      dataWriter.writeVInt(pos - lastPos);
      IntSequenceOutputs.getSingleton().write(currentOrds.get(), dataWriter);
      if (collision != 0) {
        dataWriter.writeByte(collision < 0 ? (byte) 1 : 0);
        dataWriter.writeVInt(pos - Math.abs(collision));
      }

      group.clear();
      morphDataIDs.clear();
      return pos;
    }

    /** Grows {@link #wordData} so the next entry fits, assuming worst-case VINT widths. */
    private void ensureArraySize(int valueLength, boolean hasCollision) {
      int pos = dataWriter.getPosition();
      int maxEntrySize = 8 + 4 * (valueLength + 1) + (hasCollision ? 5 : 0);
      while (wordData.length < pos + maxEntrySize) {
        wordData = ArrayUtil.grow(wordData);
        dataWriter.reset(wordData, pos, wordData.length - pos);
      }
    }

    private static boolean hasHiddenFlag(char[] flags) {
      for (char flag : flags) {
        if (flag == Dictionary.HIDDEN_FLAG) {
          return true;
        }
      }
      return false;
    }

    /** Flushes the last pending entry and returns the finished, trimmed storage. */
    WordStorage build() throws IOException {
      flushGroup();
      return new WordStorage(
          hashTable, ArrayUtil.copyOfSubArray(wordData, 0, dataWriter.getPosition()));
    }
  }
}

View File

@ -160,8 +160,7 @@ public class TestAllDictionaries extends LuceneTestCase {
try {
Dictionary dic = loadDictionary(aff);
totalMemory.addAndGet(RamUsageTester.sizeOf(dic));
totalWords.addAndGet(
RamUsageTester.sizeOf(dic.words) + RamUsageTester.sizeOf(dic.wordHashes));
totalWords.addAndGet(RamUsageTester.sizeOf(dic.words));
System.out.println(aff + "\t" + memoryUsageSummary(dic));
} catch (Throwable e) {
failures.add(aff);

View File

@ -57,12 +57,12 @@ public class TestPerformance extends LuceneTestCase {
@Test
public void en() throws Exception {
checkAnalysisPerformance("en", 1_000_000);
checkAnalysisPerformance("en", 1_200_000);
}
@Test
public void en_suggest() throws Exception {
checkSuggestionPerformance("en", 1_200);
checkSuggestionPerformance("en", 3_000);
}
@Test
@ -72,7 +72,7 @@ public class TestPerformance extends LuceneTestCase {
@Test
public void de_suggest() throws Exception {
checkSuggestionPerformance("de", 55);
checkSuggestionPerformance("de", 60);
}
@Test