LUCENE-3714: add weighted FST suggester impl

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1291020 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Robert Muir 2012-02-19 16:23:05 +00:00
parent 2e07171d1f
commit a519b630ee
13 changed files with 981 additions and 2 deletions

View File

@ -858,6 +858,9 @@ New Features
* LUCENE-3725: Added optional packing to FST building; this uses extra
RAM during building but results in a smaller FST. (Mike McCandless)
* LUCENE-3714: Add top N shortest cost paths search for FST<Long>.
(Robert Muir, Dawid Weiss, Mike McCandless)
Bug fixes

View File

@ -150,6 +150,9 @@ New Features
* LUCENE-3602: Added query time joining under the join module. (Martijn van Groningen, Michael McCandless)
* LUCENE-3714: Add WFSTCompletionLookup suggester that supports more fine-grained
ranking for suggestions. (Mike McCandless, Dawid Weiss, Robert Muir)
API Changes
* LUCENE-3596: DirectoryTaxonomyWriter.openIndexWriter() now takes an

View File

@ -234,7 +234,302 @@ public final class Util {
}
}
}
/**
 * A (partial) path through the FST: the arc the path currently ends on, the
 * cost accumulated so far and the input labels consumed so far.
 *
 * <p>Paths order by ascending cost; equal costs are ordered by their input.
 */
private static class FSTPath implements Comparable<FSTPath> {
  public FST.Arc<Long> arc;
  public long cost;
  public final IntsRef input = new IntsRef();

  /** Creates a path with the given cost, ending on a private copy of {@code arc}. */
  public FSTPath(long cost, FST.Arc<Long> arc) {
    final FST.Arc<Long> ownArc = new FST.Arc<Long>();
    this.arc = ownArc.copyFrom(arc);
    this.cost = cost;
  }

  @Override
  public String toString() {
    final StringBuilder sb = new StringBuilder();
    sb.append("input=").append(input);
    sb.append(" cost=").append(cost);
    return sb.toString();
  }

  @Override
  public int compareTo(FSTPath other) {
    // Primary key: cost (lower sorts first).
    if (cost != other.cost) {
      return cost < other.cost ? -1 : 1;
    }
    // Secondary key: the input sequence.
    return input.compareTo(other.input);
  }
}
/**
 * Collects up to {@code topN} lowest-cost completions that start at
 * {@code fromNode} in an {@code FST<Long>}.
 *
 * <p>{@link #search()} repeatedly takes the cheapest pending path and extends
 * it by following arcs whose output is the FST's "no output" value until an
 * {@code FST.END_LABEL} arc is reached. While extending, the other outgoing
 * arcs are offered to {@link #queue} via {@link #addIfCompetitive(FSTPath)}.
 */
private static class TopNSearcher {
private final FST<Long> fst;
private final FST.Arc<Long> fromNode;
private final int topN;
// Worst (last) entry of the queue; only set once the queue has reached
// topN entries, and then used to reject paths that cannot compete:
// Set once the queue has filled:
FSTPath bottom = null;
// Pending candidate paths, ordered by FSTPath.compareTo (cost, then input).
// Created in search() only when topN > 1, and reset to null there once the
// last wanted path is being pursued.
TreeSet<FSTPath> queue = null;
/**
 * @param fst the FST to search
 * @param fromNode arc whose target node is the starting point of all completions
 * @param topN maximum number of results that search() will return
 */
public TopNSearcher(FST<Long> fst, FST.Arc<Long> fromNode, int topN) {
this.fst = fst;
this.topN = topN;
this.fromNode = fromNode;
}
/**
 * Offers the extension of {@code path} by its current arc ({@code path.arc})
 * to the queue. The extension's cost is {@code path.cost + path.arc.output}.
 * Once {@link #bottom} is set, an extension that is more expensive than
 * bottom, or ties on cost but sorts after bottom's input, is dropped.
 * Otherwise a new FSTPath (with its own copy of the input plus the arc's
 * label) is added; if bottom was set, the previous bottom is evicted.
 * {@code path} itself is left unchanged on return.
 */
// If back plus this arc is competitive then add to queue:
private void addIfCompetitive(FSTPath path) {
assert queue != null;
long cost = path.cost + path.arc.output;
//System.out.println(" addIfCompetitive bottom=" + bottom + " queue.size()=" + queue.size());
if (bottom != null) {
if (cost > bottom.cost) {
// Doesn't compete
return;
} else if (cost == bottom.cost) {
// Tie break by alpha sort on the input:
// (temporarily append the arc's label to path.input for the comparison,
// then undo it via length-- below)
path.input.grow(path.input.length+1);
path.input.ints[path.input.length++] = path.arc.label;
final int cmp = bottom.input.compareTo(path.input);
path.input.length--;
assert cmp != 0;
if (cmp < 0) {
// Doesn't compete
return;
}
}
// Competes
} else {
// Queue isn't full yet, so any path we hit competes:
}
// Build an independent path object: copy of the input so far plus this arc's label.
final FSTPath newPath = new FSTPath(cost, path.arc);
newPath.input.grow(path.input.length+1);
System.arraycopy(path.input.ints, 0, newPath.input.ints, 0, path.input.length);
newPath.input.ints[path.input.length] = path.arc.label;
newPath.input.length = path.input.length+1;
//System.out.println(" add path=" + newPath);
queue.add(newPath);
if (bottom != null) {
// Queue was already at capacity: drop the old worst entry and track the new one.
final FSTPath removed = queue.pollLast();
assert removed == bottom;
bottom = queue.last();
//System.out.println(" now re-set bottom: " + bottom + " queue=" + queue);
} else if (queue.size() == topN) {
// Queue just filled up:
bottom = queue.last();
//System.out.println(" now set bottom: " + bottom);
}
}
/**
 * Runs the search.
 *
 * @return the completions found, in the order they were completed; at most
 *         {@code topN} entries, fewer if the queue runs out of paths
 * @throws IOException declared because the FST arc-reading methods are called
 */
public MinResult[] search() throws IOException {
//System.out.println(" search topN=" + topN);
final FST.Arc<Long> scratchArc = new FST.Arc<Long>();
final List<MinResult> results = new ArrayList<MinResult>();
final Long NO_OUTPUT = fst.outputs.getNoOutput();
// TODO: we could enable FST to sorting arcs by weight
// as it freezes... can easily do this on first pass
// (w/o requiring rewrite)
// TODO: maybe we should make an FST.INPUT_TYPE.BYTE0.5!?
// (nibbles)
// For each top N path:
while (results.size() < topN) {
//System.out.println("\nfind next path");
FSTPath path;
if (queue == null) {
if (results.size() != 0) {
// Ran out of paths
break;
}
// First pass (top path): start from original fromNode
if (topN > 1) {
queue = new TreeSet<FSTPath>();
}
long minArcCost = Long.MAX_VALUE;
FST.Arc<Long> minArc = null;
path = new FSTPath(0, fromNode);
fst.readFirstTargetArc(fromNode, path.arc);
// Bootstrap: find the min starting arc
while (true) {
long arcScore = path.arc.output;
//System.out.println(" arc=" + (char) path.arc.label + " cost=" + arcScore);
if (arcScore < minArcCost) {
minArcCost = arcScore;
minArc = scratchArc.copyFrom(path.arc);
//System.out.println(" **");
}
if (queue != null) {
addIfCompetitive(path);
}
if (path.arc.isLast()) {
break;
}
fst.readNextArc(path.arc);
}
assert minArc != null;
if (queue != null) {
// Remove top path since we are now going to
// pursue it:
path = queue.pollFirst();
//System.out.println(" remove init path=" + path);
assert path.arc.label == minArc.label;
if (bottom != null && queue.size() == topN-1) {
bottom = queue.last();
//System.out.println(" set init bottom: " + bottom);
}
} else {
// No queue (topN == 1): continue directly with the cheapest starting arc.
path.arc.copyFrom(minArc);
path.input.grow(1);
path.input.ints[0] = minArc.label;
path.input.length = 1;
path.cost = minArc.output;
}
} else {
path = queue.pollFirst();
if (path == null) {
// There were less than topN paths available:
break;
}
}
if (path.arc.label == FST.END_LABEL) {
//System.out.println(" empty string! cost=" + path.cost);
// Empty string!
// (drop the END_LABEL that was appended to the input as the path's last label)
path.input.length--;
results.add(new MinResult(path.input, path.cost));
continue;
}
if (results.size() == topN-1) {
// Last path -- don't bother w/ queue anymore:
queue = null;
}
//System.out.println(" path: " + path);
// We take path and find its "0 output completion",
// ie, just keep traversing the first arc with
// NO_OUTPUT that we can find, since this must lead
// to the minimum path that completes from
// path.arc.
// For each input letter:
while (true) {
//System.out.println("\n cycle path: " + path);
fst.readFirstTargetArc(path.arc, path.arc);
// For each arc leaving this node:
boolean foundZero = false;
while(true) {
//System.out.println(" arc=" + (char) path.arc.label + " cost=" + path.arc.output);
// NOTE: == compares the two Long references, not their values.
// NOTE(review): presumably this is why shortestPaths asks for shared
// outputs (doShare=true) -- confirm against PositiveIntOutputs.
if (path.arc.output == NO_OUTPUT) {
if (queue == null) {
// No queue: just follow the first NO_OUTPUT arc; path.arc stays on it.
foundZero = true;
break;
} else if (!foundZero) {
// Remember the first NO_OUTPUT arc; it is followed after this scan.
scratchArc.copyFrom(path.arc);
foundZero = true;
} else {
addIfCompetitive(path);
}
} else if (queue != null) {
addIfCompetitive(path);
}
if (path.arc.isLast()) {
break;
}
fst.readNextArc(path.arc);
}
assert foundZero;
if (queue != null) {
// TODO: maybe we can save this copyFrom if we
// are more clever above... eg on finding the
// first NO_OUTPUT arc we'd switch to using
// scratchArc
path.arc.copyFrom(scratchArc);
}
if (path.arc.label == FST.END_LABEL) {
// Add final output:
//System.out.println(" done!: " + path);
results.add(new MinResult(path.input, path.cost + path.arc.output));
break;
} else {
// Consume the arc: append its label and add its output to the running cost.
path.input.grow(1+path.input.length);
path.input.ints[path.input.length] = path.arc.label;
path.input.length++;
path.cost += path.arc.output;
}
}
}
return results.toArray(new MinResult[results.size()]);
}
}
// TODO: parameterize the FST type <T> and allow passing in a
// comparator; eg maybe your output is a PairOutput and
// one of the outputs in the pair is monotonic so you
// compare by that
/**
 * One completion: the input labels of the path and its total (Long) output.
 * Results order by ascending output; equal outputs are ordered by input.
 */
public final static class MinResult implements Comparable<MinResult> {
  public final IntsRef input;
  public final long output;

  public MinResult(IntsRef input, long output) {
    this.input = input;
    this.output = output;
  }

  @Override
  public int compareTo(MinResult other) {
    final int byOutput = output < other.output ? -1 : (output > other.output ? 1 : 0);
    // Only fall back to the input when the outputs tie.
    return byOutput != 0 ? byOutput : input.compareTo(other.input);
  }
}
/**
 * Finds the top N minimum-cost (Long output) completions from the given
 * node to a final node.
 *
 * <p>NOTE: the FST must have been built with shared outputs (pass
 * doShare=true to {@link PositiveIntOutputs#getSingleton}).
 *
 * @param fst the FST to search
 * @param fromNode arc to start completing from
 * @param topN maximum number of completions to return
 * @return the completions found, as produced by the searcher
 */
public static MinResult[] shortestPaths(FST<Long> fst, FST.Arc<Long> fromNode, int topN) throws IOException {
  final TopNSearcher searcher = new TopNSearcher(fst, fromNode, topN);
  return searcher.search();
}
/**
* Dumps an {@link FST} to a GraphViz's <code>dot</code> language description
* for visualization. Example of use:

View File

@ -58,6 +58,7 @@ import org.apache.lucene.util.LuceneTestCase.UseNoMemoryExpensiveCodec;
import org.apache.lucene.util.UnicodeUtil;
import org.apache.lucene.util._TestUtil;
import org.apache.lucene.util.fst.FST.Arc;
import org.apache.lucene.util.fst.FST.BytesReader;
@UseNoMemoryExpensiveCodec
public class TestFSTs extends LuceneTestCase {
@ -1975,6 +1976,119 @@ public class TestFSTs extends LuceneTestCase {
assertEquals(42, arc.output.longValue());
}
/** Checks the top-3 shortest paths of a tiny, hand-built FST. */
public void testShortestPaths() throws Exception {
  final PositiveIntOutputs outputs = PositiveIntOutputs.getSingleton(true);
  final Builder<Long> builder = new Builder<Long>(FST.INPUT_TYPE.BYTE1, outputs);
  final IntsRef scratch = new IntsRef();

  builder.add(Util.toIntsRef(new BytesRef("aab"), scratch), 22L);
  builder.add(Util.toIntsRef(new BytesRef("aac"), scratch), 7L);
  builder.add(Util.toIntsRef(new BytesRef("ax"), scratch), 17L);
  final FST<Long> fst = builder.finish();

  // To inspect the automaton while debugging:
  //Writer w = new OutputStreamWriter(new FileOutputStream("out.dot"));
  //Util.toDot(fst, w, false, false);
  //w.close();

  final FST.Arc<Long> root = fst.getFirstArc(new FST.Arc<Long>());
  final Util.MinResult[] r = Util.shortestPaths(fst, root, 3);

  // Expected completions, cheapest first.
  final String[] expectedInputs = {"aac", "ax", "aab"};
  final long[] expectedOutputs = {7, 17, 22};

  assertEquals(3, r.length);
  for (int i = 0; i < expectedInputs.length; i++) {
    assertEquals(Util.toIntsRef(new BytesRef(expectedInputs[i]), scratch), r[i].input);
    assertEquals(expectedOutputs[i], r[i].output);
  }
}
/**
 * Builds an FST from random unique words with random weights, then for every
 * proper non-empty prefix compares Util.shortestPaths against a brute-force
 * scan of the word map.
 */
public void testShortestPathsRandom() throws Exception {
  final int numWords = atLeast(1000);

  final TreeMap<String,Long> slowCompletor = new TreeMap<String,Long>();
  final TreeSet<String> allPrefixes = new TreeSet<String>();

  final PositiveIntOutputs outputs = PositiveIntOutputs.getSingleton(true);
  final Builder<Long> builder = new Builder<Long>(FST.INPUT_TYPE.BYTE1, outputs);
  final IntsRef scratch = new IntsRef();

  for (int word = 0; word < numWords; word++) {
    // Draw until we get a word we have not seen yet.
    String s;
    do {
      s = _TestUtil.randomSimpleString(random);
    } while (slowCompletor.containsKey(s));

    for (int end = 1; end < s.length(); end++) {
      allPrefixes.add(s.substring(0, end));
    }
    final int weight = _TestUtil.nextInt(random, 1, 100); // weights 1..100
    slowCompletor.put(s, (long) weight);
  }

  // The TreeMap iterates its keys in sorted order.
  for (Map.Entry<String,Long> entry : slowCompletor.entrySet()) {
    //System.out.println("add: " + entry);
    builder.add(Util.toIntsRef(new BytesRef(entry.getKey()), scratch), entry.getValue());
  }

  final FST<Long> fst = builder.finish();
  //System.out.println("SAVE out.dot");
  //Writer w = new OutputStreamWriter(new FileOutputStream("out.dot"));
  //Util.toDot(fst, w, false, false);
  //w.close();

  final BytesReader reader = fst.getBytesReader(0);

  //System.out.println("testing: " + allPrefixes.size() + " prefixes");
  for (String prefix : allPrefixes) {
    // 1. run the prefix against the FST, summing the outputs along the way
    //System.out.println("TEST: " + prefix);
    long prefixOutput = 0;
    final FST.Arc<Long> arc = fst.getFirstArc(new FST.Arc<Long>());
    for (int idx = 0; idx < prefix.length(); idx++) {
      final int label = (int) prefix.charAt(idx);
      if (fst.findTargetArc(label, arc, arc, reader) == null) {
        fail();
      }
      prefixOutput += arc.output;
    }

    final int topN = _TestUtil.nextInt(random, 1, 10);
    final Util.MinResult[] r = Util.shortestPaths(fst, arc, topN);

    // 2. go through the whole map (slowCompletor) and collect every word that
    //    starts with the prefix, as (suffix, remaining cost)
    // TODO: could be faster... but it's slowCompletor for a reason
    final List<Util.MinResult> matches = new ArrayList<Util.MinResult>();
    for (Map.Entry<String,Long> entry : slowCompletor.entrySet()) {
      final String candidate = entry.getKey();
      if (!candidate.startsWith(prefix)) {
        continue;
      }
      //System.out.println(" consider " + candidate);
      final String suffix = candidate.substring(prefix.length());
      matches.add(new Util.MinResult(Util.toIntsRef(new BytesRef(suffix), new IntsRef()),
                                     entry.getValue() - prefixOutput));
    }

    assertTrue(matches.size() > 0);
    Collections.sort(matches);

    // Only the best topN brute-force matches are expected back.
    final int expectedHits = Math.min(topN, matches.size());
    assertEquals(expectedHits, r.length);

    for (int hit = 0; hit < r.length; hit++) {
      //System.out.println(" check hit " + hit);
      final Util.MinResult expected = matches.get(hit);
      assertEquals(expected.input, r[hit].input);
      assertEquals(expected.output, r[hit].output);
    }
  }
}
public void testLargeOutputsOnArrayArcs() throws Exception {
final ByteSequenceOutputs outputs = ByteSequenceOutputs.getSingleton();
final Builder<BytesRef> builder = new Builder<BytesRef>(FST.INPUT_TYPE.BYTE1, outputs);

View File

@ -0,0 +1,280 @@
package org.apache.lucene.search.suggest.fst;
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.lucene.search.spell.TermFreqIterator;
import org.apache.lucene.search.suggest.Lookup;
import org.apache.lucene.store.ByteArrayDataInput;
import org.apache.lucene.store.ByteArrayDataOutput;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.IntsRef;
import org.apache.lucene.util.UnicodeUtil;
import org.apache.lucene.util.fst.Builder;
import org.apache.lucene.util.fst.FST;
import org.apache.lucene.util.fst.FST.Arc;
import org.apache.lucene.util.fst.FST.BytesReader;
import org.apache.lucene.util.fst.PositiveIntOutputs;
import org.apache.lucene.util.fst.Util;
import org.apache.lucene.util.fst.Util.MinResult;
/**
 * Suggester based on a weighted FST: it first traverses the prefix,
 * then walks the <i>n</i> shortest paths to retrieve top-ranked
 * suggestions.
 * <p>
 * <b>NOTE</b>: Although the {@link TermFreqIterator} API specifies
 * floating point weights, input weights should be whole numbers.
 * Input weights will be cast to a Java integer, and any
 * negative, infinite, or NaN values will be rejected.
 *
 * @see Util#shortestPaths(FST, FST.Arc, int)
 * @lucene.experimental
 */
public class WFSTCompletionLookup extends Lookup {

  /**
   * File name for the automaton.
   *
   * @see #store(File)
   * @see #load(File)
   */
  private static final String FILENAME = "wfst.bin";

  /**
   * {@code FST<Long>}, weights are encoded as costs: (Integer.MAX_VALUE-weight)
   */
  // NOTE: like FSTSuggester, this is really a WFSA, if you want to
  // customize the code to add some output you should use PairOutputs.
  private FST<Long> fst = null;

  /**
   * True if exact match suggestions should always be returned first.
   */
  private final boolean exactFirst;

  /**
   * Calls {@link #WFSTCompletionLookup(boolean) WFSTCompletionLookup(true)}
   */
  public WFSTCompletionLookup() {
    this(true);
  }

  /**
   * Creates a new suggester.
   *
   * @param exactFirst <code>true</code> if suggestions that match the
   *        prefix exactly should always be returned first, regardless
   *        of score. This has no performance impact, but could result
   *        in low-quality suggestions.
   */
  public WFSTCompletionLookup(boolean exactFirst) {
    this.exactFirst = exactFirst;
  }

  /**
   * Builds the FST from the given terms and weights.
   *
   * <p>Each entry is written to a temporary file as
   * {@code utf8(term) + 0x00 + int(cost)}, the file is sorted, and the sorted
   * entries are fed to the FST builder. Only the first entry of a run of
   * identical terms is added. Both temporary files are deleted on exit.
   *
   * @throws UnsupportedOperationException if a weight is NaN, infinite,
   *         negative or larger than Integer.MAX_VALUE (see encodeWeight)
   */
  @Override
  public void build(TermFreqIterator iterator) throws IOException {
    String prefix = getClass().getSimpleName();
    File directory = Sort.defaultTempDir();
    File tempInput = File.createTempFile(prefix, ".input", directory);
    File tempSorted = File.createTempFile(prefix, ".sorted", directory);

    Sort.ByteSequencesWriter writer = new Sort.ByteSequencesWriter(tempInput);
    Sort.ByteSequencesReader reader = null;
    BytesRef scratch = new BytesRef();

    boolean success = false;
    try {
      byte [] buffer = new byte [0];
      ByteArrayDataOutput output = new ByteArrayDataOutput(buffer);
      while (iterator.hasNext()) {
        String key = iterator.next();
        UnicodeUtil.UTF16toUTF8(key, 0, key.length(), scratch);

        // 5 extra bytes: 1 separator + 4 for the int cost
        if (scratch.length + 5 >= buffer.length) {
          buffer = ArrayUtil.grow(buffer, scratch.length + 5);
        }

        output.reset(buffer);
        output.writeBytes(scratch.bytes, scratch.offset, scratch.length);
        output.writeByte((byte)0); // separator: not used, just for sort order
        output.writeInt((int)encodeWeight(iterator.freq()));
        writer.write(buffer, 0, output.getPosition());
      }
      writer.close();
      new Sort().sort(tempInput, tempSorted);
      reader = new Sort.ByteSequencesReader(tempSorted);

      PositiveIntOutputs outputs = PositiveIntOutputs.getSingleton(true);
      Builder<Long> builder = new Builder<Long>(FST.INPUT_TYPE.BYTE1, outputs);

      BytesRef previous = null;
      BytesRef suggestion = new BytesRef();
      IntsRef scratchInts = new IntsRef();
      ByteArrayDataInput input = new ByteArrayDataInput();
      while (reader.read(scratch)) {
        // View the term part of the record without copying it.
        suggestion.bytes = scratch.bytes;
        suggestion.offset = scratch.offset;
        suggestion.length = scratch.length - 5; // int + separator

        input.reset(scratch.bytes);
        input.skipBytes(suggestion.length + 1); // suggestion + separator
        long cost = input.readInt();

        if (previous == null) {
          previous = new BytesRef();
        } else if (suggestion.equals(previous)) {
          continue; // for duplicate suggestions, the best weight is actually added
        }
        Util.toIntsRef(suggestion, scratchInts);
        builder.add(scratchInts, cost);
        previous.copyBytes(suggestion);
      }
      fst = builder.finish();
      success = true;
    } finally {
      if (success) {
        IOUtils.close(reader, writer);
      } else {
        IOUtils.closeWhileHandlingException(reader, writer);
      }

      tempInput.delete();
      tempSorted.delete();
    }
  }

  /** Saves the FST to the file {@code wfst.bin} inside {@code storeDir}. */
  @Override
  public boolean store(File storeDir) throws IOException {
    fst.save(new File(storeDir, FILENAME));
    return true;
  }

  /** Loads the FST from the file {@code wfst.bin} inside {@code storeDir}. */
  @Override
  public boolean load(File storeDir) throws IOException {
    this.fst = FST.read(new File(storeDir, FILENAME), PositiveIntOutputs.getSingleton(true));
    return true;
  }

  /**
   * Returns up to {@code num} suggestions that start with {@code key}.
   *
   * <p>If this suggester was created with {@code exactFirst} and {@code key}
   * itself is a stored suggestion, it is returned first; the remaining slots
   * are filled with the best completions, never repeating the exact match.
   *
   * @param key the prefix to complete
   * @param onlyMorePopular not used by this implementation
   * @param num maximum number of suggestions; must be &gt; 0
   * @return the suggestions, or an empty list if nothing starts with {@code key}
   */
  @Override
  public List<LookupResult> lookup(String key, boolean onlyMorePopular, int num) {
    assert num > 0;
    final int limit = num;
    BytesRef scratch = new BytesRef(key);
    int prefixLength = scratch.length;
    Arc<Long> arc = new Arc<Long>();

    // match the prefix portion exactly
    Long prefixOutput = null;
    try {
      prefixOutput = lookupPrefix(scratch, arc);
    } catch (IOException bogus) { throw new RuntimeException(bogus); }

    if (prefixOutput == null) {
      return Collections.<LookupResult>emptyList();
    }

    List<LookupResult> results = new ArrayList<LookupResult>(num);
    boolean exactAdded = false;
    if (exactFirst && arc.isFinal()) {
      results.add(new LookupResult(scratch.utf8ToString(), decodeWeight(prefixOutput + arc.nextFinalOutput)));
      if (--num == 0) {
        return results; // that was quick
      }
      exactAdded = true;
    }

    // complete top-N
    // The search below starts from the same (final) arc, so it also produces
    // the empty completion, i.e. the exact match again. When the exact match
    // was already emitted, ask for one extra path so that the empty completion
    // can be skipped without coming up short.
    final int wanted = exactAdded ? num + 1 : num;
    MinResult completions[] = null;
    try {
      completions = Util.shortestPaths(fst, arc, wanted);
    } catch (IOException bogus) { throw new RuntimeException(bogus); }

    BytesRef suffix = new BytesRef(8);
    for (MinResult completion : completions) {
      if (exactAdded) {
        if (completion.input.length == 0) {
          continue; // the exact match: already the first result
        }
        if (results.size() >= limit) {
          break; // the extra path was not needed
        }
      }
      scratch.length = prefixLength;
      // append suffix
      Util.toBytesRef(completion.input, suffix);
      scratch.append(suffix);

      results.add(new LookupResult(scratch.utf8ToString(), decodeWeight(prefixOutput + completion.output)));
    }
    return results;
  }

  /**
   * Walks the FST along the bytes of {@code scratch}, leaving {@code arc} on
   * the last arc followed.
   *
   * @return the sum of the outputs along the way, or null if the FST has no
   *         path for these bytes
   */
  private Long lookupPrefix(BytesRef scratch, Arc<Long> arc) throws /*Bogus*/IOException {
    assert 0 == fst.outputs.getNoOutput().longValue();
    long output = 0;
    BytesReader bytesReader = fst.getBytesReader(0);

    fst.getFirstArc(arc);

    byte[] bytes = scratch.bytes;
    int pos = scratch.offset;
    int end = pos + scratch.length;
    while (pos < end) {
      // & 0xff: labels are unsigned byte values
      if (fst.findTargetArc(bytes[pos++] & 0xff, arc, arc, bytesReader) == null) {
        return null;
      } else {
        output += arc.output.longValue();
      }
    }

    return output;
  }

  /** Not supported: always returns false. */
  @Override
  public boolean add(String key, Object value) {
    return false; // Not supported.
  }

  /**
   * Returns the weight associated with an input string,
   * or null if it does not exist.
   */
  @Override
  public Float get(String key) {
    Arc<Long> arc = new Arc<Long>();
    Long result = null;
    try {
      result = lookupPrefix(new BytesRef(key), arc);
    } catch (IOException bogus) { throw new RuntimeException(bogus); }
    if (result == null || !arc.isFinal()) {
      return null;
    } else {
      return decodeWeight(result + arc.nextFinalOutput);
    }
  }

  /** cost -> weight */
  private static float decodeWeight(long encoded) {
    return Integer.MAX_VALUE - encoded;
  }

  /** weight -> cost */
  private static long encodeWeight(float value) {
    if (Float.isNaN(value) || Float.isInfinite(value) || value < 0 || value > Integer.MAX_VALUE) {
      throw new UnsupportedOperationException("cannot encode value: " + value);
    }
    return Integer.MAX_VALUE - (int)value;
  }
}

View File

@ -32,6 +32,7 @@ import java.util.concurrent.Callable;
import org.apache.lucene.util.*;
import org.apache.lucene.search.suggest.Lookup;
import org.apache.lucene.search.suggest.fst.FSTCompletionLookup;
import org.apache.lucene.search.suggest.fst.WFSTCompletionLookup;
import org.apache.lucene.search.suggest.jaspell.JaspellLookup;
import org.apache.lucene.search.suggest.tst.TSTLookup;
@ -47,7 +48,8 @@ public class LookupBenchmarkTest extends LuceneTestCase {
private final List<Class<? extends Lookup>> benchmarkClasses = Arrays.asList(
JaspellLookup.class,
TSTLookup.class,
FSTCompletionLookup.class);
FSTCompletionLookup.class,
WFSTCompletionLookup.class);
private final static int rounds = 15;
private final static int warmup = 5;
@ -72,6 +74,7 @@ public class LookupBenchmarkTest extends LuceneTestCase {
*/
@BeforeClass
public static void setup() throws Exception {
assert false : "disable assertions before running benchmarks!";
List<TermFreq> input = readTop50KWiki();
Collections.shuffle(input, random);
LookupBenchmarkTest.dictionaryInput = input.toArray(new TermFreq [input.size()]);

View File

@ -0,0 +1,148 @@
package org.apache.lucene.search.suggest.fst;
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.TreeSet;
import org.apache.lucene.search.suggest.Lookup.LookupResult;
import org.apache.lucene.search.suggest.TermFreq;
import org.apache.lucene.search.suggest.TermFreqArrayIterator;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util._TestUtil;
/** Tests for {@link WFSTCompletionLookup}. */
public class WFSTCompletionTest extends LuceneTestCase {

  /** Asserts the key and (within 0.01) the value of a single suggestion. */
  private static void assertHit(String expectedKey, float expectedValue, LookupResult actual) {
    assertEquals(expectedKey, actual.key);
    assertEquals(expectedValue, actual.value, 0.01F);
  }

  /** Fixed-input checks of ranking and of the default exact-match-first behavior. */
  public void test() throws Exception {
    final TermFreq[] keys = {
        new TermFreq("foo", 50),
        new TermFreq("bar", 10),
        new TermFreq("barbar", 12),
        new TermFreq("barbara", 6)
    };

    final WFSTCompletionLookup suggester = new WFSTCompletionLookup();
    suggester.build(new TermFreqArrayIterator(keys));

    // top N of 2, but only foo is available
    List<LookupResult> results = suggester.lookup("f", false, 2);
    assertEquals(1, results.size());
    assertHit("foo", 50, results.get(0));

    // top N of 1 for 'bar': we return this even though barbar is higher
    results = suggester.lookup("bar", false, 1);
    assertEquals(1, results.size());
    assertHit("bar", 10, results.get(0));

    // top N Of 2 for 'b'
    results = suggester.lookup("b", false, 2);
    assertEquals(2, results.size());
    assertHit("barbar", 12, results.get(0));
    assertHit("bar", 10, results.get(1));

    // top N of 3 for 'ba'
    results = suggester.lookup("ba", false, 3);
    assertEquals(3, results.size());
    assertHit("barbar", 12, results.get(0));
    assertHit("bar", 10, results.get(1));
    assertHit("barbara", 6, results.get(2));
  }

  /**
   * Random words with random weights: every proper non-empty prefix is looked
   * up (exactFirst disabled) and compared against a brute-force scan.
   */
  public void testRandom() throws Exception {
    final int numWords = atLeast(1000);

    final TreeMap<String,Long> slowCompletor = new TreeMap<String,Long>();
    final TreeSet<String> allPrefixes = new TreeSet<String>();

    final TermFreq[] keys = new TermFreq[numWords];

    for (int word = 0; word < numWords; word++) {
      // Draw until we get a word we have not seen yet.
      String s;
      do {
        // TODO: would be nice to fix this slowCompletor/comparator to
        // use full range, but we might lose some coverage too...
        s = _TestUtil.randomSimpleString(random);
      } while (slowCompletor.containsKey(s));

      for (int end = 1; end < s.length(); end++) {
        allPrefixes.add(s.substring(0, end));
      }
      // we can probably do Integer.MAX_VALUE here, but why worry.
      final int weight = random.nextInt(1<<24);
      slowCompletor.put(s, (long) weight);
      keys[word] = new TermFreq(s, (float) weight);
    }

    final WFSTCompletionLookup suggester = new WFSTCompletionLookup(false);
    suggester.build(new TermFreqArrayIterator(keys));

    // Expected order: highest value first, ties broken by key.
    final Comparator<LookupResult> bestFirst = new Comparator<LookupResult>() {
      public int compare(LookupResult left, LookupResult right) {
        final int byValue = Float.compare(right.value, left.value);
        return byValue != 0 ? byValue : left.key.compareTo(right.key);
      }
    };

    for (String prefix : allPrefixes) {
      final int topN = _TestUtil.nextInt(random, 1, 10);
      final List<LookupResult> r = suggester.lookup(prefix, false, topN);

      // 2. go through the whole map (slowCompletor) and check it's actually the best suggestion
      // TODO: could be faster... but it's slowCompletor for a reason
      final List<LookupResult> matches = new ArrayList<LookupResult>();
      for (Map.Entry<String,Long> entry : slowCompletor.entrySet()) {
        final String candidate = entry.getKey();
        if (candidate.startsWith(prefix)) {
          matches.add(new LookupResult(candidate, (float) entry.getValue().longValue()));
        }
      }

      assertTrue(matches.size() > 0);
      Collections.sort(matches, bestFirst);

      // Only the best topN brute-force matches are expected back.
      final int expectedHits = Math.min(topN, matches.size());
      assertEquals(expectedHits, r.size());

      for (int hit = 0; hit < r.size(); hit++) {
        //System.out.println(" check hit " + hit);
        final LookupResult expected = matches.get(hit);
        final LookupResult actual = r.get(hit);
        assertEquals(expected.key, actual.key);
        assertEquals(expected.value, actual.value, 0f);
      }
    }
  }
}

View File

@ -501,6 +501,9 @@ New Features
* SOLR-3105: ElisionFilterFactory optionally allows the parameter
ignoreCase (default=false). (Robert Muir)
* LUCENE-3714: Add WFSTLookupFactory, a suggester that uses a weighted FST
for more fine-grained suggestions. (Mike McCandless, Dawid Weiss, Robert Muir)
Optimizations
----------------------
* SOLR-1931: Speedup for LukeRequestHandler and admin/schema browser. New parameter

View File

@ -0,0 +1,45 @@
package org.apache.solr.spelling.suggest.fst;
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import org.apache.lucene.search.suggest.Lookup;
import org.apache.lucene.search.suggest.fst.*;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.core.SolrCore;
import org.apache.solr.spelling.suggest.LookupFactory;
/**
 * Factory for {@link WFSTCompletionLookup}
 * @lucene.experimental
 */
public class WFSTLookupFactory extends LookupFactory {
  /**
   * If <code>true</code>, exact suggestions are returned first, even if they are prefixes
   * of other strings in the automaton (possibly with larger weights).
   */
  public static final String EXACT_MATCH_FIRST = "exactMatchFirst";

  /**
   * Creates a {@link WFSTCompletionLookup}, reading {@link #EXACT_MATCH_FIRST}
   * from {@code params}; the option defaults to <code>true</code> when absent.
   */
  @Override
  public Lookup create(NamedList params, SolrCore core) {
    final Object configured = params.get(EXACT_MATCH_FIRST);
    final boolean exactMatchFirst;
    if (configured == null) {
      exactMatchFirst = true;
    } else {
      exactMatchFirst = Boolean.valueOf(configured.toString());
    }
    return new WFSTCompletionLookup(exactMatchFirst);
  }
}

View File

@ -75,6 +75,21 @@
<bool name="exactMatchFirst">true</bool>
</lst>
</searchComponent>
<!-- WFSTLookup suggest component -->
<searchComponent class="solr.SpellCheckComponent" name="suggest_wfst">
<lst name="spellchecker">
<str name="name">suggest_wfst</str>
<str name="classname">org.apache.solr.spelling.suggest.Suggester</str>
<str name="lookupImpl">org.apache.solr.spelling.suggest.fst.WFSTLookupFactory</str>
<str name="field">suggest</str>
<str name="storeDir">suggest_wfst</str>
<str name="buildOnCommit">true</str>
<!-- Suggester properties -->
<bool name="exactMatchFirst">true</bool>
</lst>
</searchComponent>
<!-- The default (jaspell) -->
<requestHandler class="org.apache.solr.handler.component.SearchHandler" name="/suggest">
@ -112,4 +127,16 @@
</arr>
</requestHandler>
<!-- wfst (weighted finite state automaton based) -->
<requestHandler class="org.apache.solr.handler.component.SearchHandler" name="/suggest_wfst">
<lst name="defaults">
<str name="spellcheck">true</str>
<str name="spellcheck.dictionary">suggest_wfst</str>
<str name="spellcheck.collate">false</str>
</lst>
<arr name="components">
<str>suggest_wfst</str>
</arr>
</requestHandler>
</config>

View File

@ -1,5 +1,22 @@
package org.apache.solr.spelling.suggest;
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
public class SuggesterFSTTest extends SuggesterTest {
public SuggesterFSTTest() {
super.requestUri = "/suggest_fst";

View File

@ -1,5 +1,22 @@
package org.apache.solr.spelling.suggest;
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
public class SuggesterTSTTest extends SuggesterTest {
public SuggesterTSTTest() {
super.requestUri = "/suggest_tst";

View File

@ -0,0 +1,24 @@
package org.apache.solr.spelling.suggest;
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
 * Variant of {@link SuggesterTest} whose inherited {@code requestUri} is set
 * to "/suggest_wfst".
 * NOTE(review): presumably the tests inherited from SuggesterTest then send
 * their requests to that handler -- confirm in SuggesterTest.
 */
public class SuggesterWFSTTest extends SuggesterTest {
public SuggesterWFSTTest() {
super.requestUri = "/suggest_wfst";
}
}