LUCENE-3801: generify FST shortestPaths to any output type

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1296237 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Robert Muir 2012-03-02 14:59:44 +00:00
parent 1ccea6ab0f
commit a273239db6
4 changed files with 241 additions and 81 deletions

View File

@ -859,7 +859,7 @@ New Features
* LUCENE-3725: Added optional packing to FST building; this uses extra
RAM during building but results in a smaller FST. (Mike McCandless)
* LUCENE-3714: Add top N shortest cost paths search for FST<Long>.
* LUCENE-3714: Add top N shortest cost paths search for FST.
(Robert Muir, Dawid Weiss, Mike McCandless)
Bug fixes

View File

@ -83,11 +83,6 @@ public final class Util {
}
}
// TODO: parameterize the FST type <T> and allow passing in a
// comparator; eg maybe your output is a PairOutput and
// one of the outputs in the pair is monotonic so you
// compare by that
/** Reverse lookup (lookup by output instead of by input),
* in the special case when your FSTs outputs are
* strictly ascending. This locates the input/output
@ -133,7 +128,7 @@ public final class Util {
}
}
if (fst.targetHasArcs(arc)) {
if (FST.targetHasArcs(arc)) {
//System.out.println(" targetHasArcs");
if (result.ints.length == upto) {
result.grow(1+upto);
@ -155,7 +150,7 @@ public final class Util {
final byte flags = in.readByte();
fst.readLabel(in);
final long minArcOutput;
if ((flags & fst.BIT_ARC_HAS_OUTPUT) != 0) {
if ((flags & FST.BIT_ARC_HAS_OUTPUT) != 0) {
final long arcOutput = fst.outputs.read(in);
minArcOutput = output + arcOutput;
} else {
@ -235,14 +230,16 @@ public final class Util {
}
}
private static class FSTPath implements Comparable<FSTPath> {
public FST.Arc<Long> arc;
public long cost;
private static class FSTPath<T> implements Comparable<FSTPath<T>> {
public FST.Arc<T> arc;
public T cost;
public final IntsRef input = new IntsRef();
final Comparator<T> comparator;
public FSTPath(long cost, FST.Arc<Long> arc) {
this.arc = new FST.Arc<Long>().copyFrom(arc);
public FSTPath(T cost, FST.Arc<T> arc, Comparator<T> comparator) {
this.arc = new FST.Arc<T>().copyFrom(arc);
this.cost = cost;
this.comparator = comparator;
}
@Override
@ -251,48 +248,50 @@ public final class Util {
}
@Override
public int compareTo(FSTPath other) {
if (cost < other.cost) {
return -1;
} else if (cost > other.cost) {
return 1;
} else {
public int compareTo(FSTPath<T> other) {
int cmp = comparator.compare(cost, other.cost);
if (cmp == 0) {
return input.compareTo(other.input);
} else {
return cmp;
}
}
}
private static class TopNSearcher {
private static class TopNSearcher<T> {
private final FST<Long> fst;
private final FST.Arc<Long> fromNode;
private final FST<T> fst;
private final FST.Arc<T> fromNode;
private final int topN;
final Comparator<T> comparator;
// Set once the queue has filled:
FSTPath bottom = null;
FSTPath<T> bottom = null;
TreeSet<FSTPath> queue = null;
TreeSet<FSTPath<T>> queue = null;
public TopNSearcher(FST<Long> fst, FST.Arc<Long> fromNode, int topN) {
public TopNSearcher(FST<T> fst, FST.Arc<T> fromNode, int topN, Comparator<T> comparator) {
this.fst = fst;
this.topN = topN;
this.fromNode = fromNode;
this.comparator = comparator;
}
// If back plus this arc is competitive then add to queue:
private void addIfCompetitive(FSTPath path) {
private void addIfCompetitive(FSTPath<T> path) {
assert queue != null;
long cost = path.cost + path.arc.output;
T cost = fst.outputs.add(path.cost, path.arc.output);
//System.out.println(" addIfCompetitive bottom=" + bottom + " queue.size()=" + queue.size());
if (bottom != null) {
if (cost > bottom.cost) {
int comp = comparator.compare(cost, bottom.cost);
if (comp > 0) {
// Doesn't compete
return;
} else if (cost == bottom.cost) {
} else if (comp == 0) {
// Tie break by alpha sort on the input:
path.input.grow(path.input.length+1);
path.input.ints[path.input.length++] = path.arc.label;
@ -309,7 +308,7 @@ public final class Util {
// Queue isn't full yet, so any path we hit competes:
}
final FSTPath newPath = new FSTPath(cost, path.arc);
final FSTPath<T> newPath = new FSTPath<T>(cost, path.arc, comparator);
newPath.input.grow(path.input.length+1);
System.arraycopy(path.input.ints, 0, newPath.input.ints, 0, path.input.length);
@ -319,7 +318,7 @@ public final class Util {
//System.out.println(" add path=" + newPath);
queue.add(newPath);
if (bottom != null) {
final FSTPath removed = queue.pollLast();
final FSTPath<T> removed = queue.pollLast();
assert removed == bottom;
bottom = queue.last();
//System.out.println(" now re-set bottom: " + bottom + " queue=" + queue);
@ -330,13 +329,13 @@ public final class Util {
}
}
public MinResult[] search() throws IOException {
public MinResult<T>[] search() throws IOException {
//System.out.println(" search topN=" + topN);
final FST.Arc<Long> scratchArc = new FST.Arc<Long>();
final FST.Arc<T> scratchArc = new FST.Arc<T>();
final List<MinResult> results = new ArrayList<MinResult>();
final List<MinResult<T>> results = new ArrayList<MinResult<T>>();
final Long NO_OUTPUT = fst.outputs.getNoOutput();
final T NO_OUTPUT = fst.outputs.getNoOutput();
// TODO: we could enable FST to sorting arcs by weight
// as it freezes... can easily do this on first pass
@ -349,7 +348,7 @@ public final class Util {
while (results.size() < topN) {
//System.out.println("\nfind next path");
FSTPath path;
FSTPath<T> path;
if (queue == null) {
@ -360,20 +359,20 @@ public final class Util {
// First pass (top path): start from original fromNode
if (topN > 1) {
queue = new TreeSet<FSTPath>();
queue = new TreeSet<FSTPath<T>>();
}
long minArcCost = Long.MAX_VALUE;
FST.Arc<Long> minArc = null;
T minArcCost = null;
FST.Arc<T> minArc = null;
path = new FSTPath(0, fromNode);
path = new FSTPath<T>(NO_OUTPUT, fromNode, comparator);
fst.readFirstTargetArc(fromNode, path.arc);
// Bootstrap: find the min starting arc
while (true) {
long arcScore = path.arc.output;
T arcScore = path.arc.output;
//System.out.println(" arc=" + (char) path.arc.label + " cost=" + arcScore);
if (arcScore < minArcCost) {
if (minArcCost == null || comparator.compare(arcScore, minArcCost) < 0) {
minArcCost = arcScore;
minArc = scratchArc.copyFrom(path.arc);
//System.out.println(" **");
@ -419,7 +418,7 @@ public final class Util {
//System.out.println(" empty string! cost=" + path.cost);
// Empty string!
path.input.length--;
results.add(new MinResult(path.input, path.cost));
results.add(new MinResult<T>(path.input, path.cost, comparator));
continue;
}
@ -439,15 +438,16 @@ public final class Util {
// For each input letter:
while (true) {
//System.out.println("\n cycle path: " + path);
//System.out.println("\n cycle path: " + path);
fst.readFirstTargetArc(path.arc, path.arc);
// For each arc leaving this node:
boolean foundZero = false;
while(true) {
//System.out.println(" arc=" + (char) path.arc.label + " cost=" + path.arc.output);
if (path.arc.output == NO_OUTPUT) {
// tricky: instead of comparing output == 0, we must
// express it via the comparator compare(output, 0) == 0
if (comparator.compare(NO_OUTPUT, path.arc.output) == 0) {
if (queue == null) {
foundZero = true;
break;
@ -479,13 +479,13 @@ public final class Util {
if (path.arc.label == FST.END_LABEL) {
// Add final output:
//System.out.println(" done!: " + path);
results.add(new MinResult(path.input, path.cost + path.arc.output));
results.add(new MinResult<T>(path.input, fst.outputs.add(path.cost, path.arc.output), comparator));
break;
} else {
path.input.grow(1+path.input.length);
path.input.ints[path.input.length] = path.arc.label;
path.input.length++;
path.cost += path.arc.output;
path.cost = fst.outputs.add(path.cost, path.arc.output);
}
}
}
@ -494,40 +494,36 @@ public final class Util {
}
}
// TODO: parameterize the FST type <T> and allow passing in a
// comparator; eg maybe your output is a PairOutput and
// one of the outputs in the pair is monotonic so you
// compare by that
public final static class MinResult implements Comparable<MinResult> {
public final static class MinResult<T> implements Comparable<MinResult<T>> {
public final IntsRef input;
public final long output;
public MinResult(IntsRef input, long output) {
public final T output;
final Comparator<T> comparator;
public MinResult(IntsRef input, T output, Comparator<T> comparator) {
this.input = input;
this.output = output;
this.comparator = comparator;
}
@Override
public int compareTo(MinResult other) {
if (output < other.output) {
return -1;
} else if (output > other.output) {
return 1;
} else {
public int compareTo(MinResult<T> other) {
int cmp = comparator.compare(output, other.output);
if (cmp == 0) {
return input.compareTo(other.input);
} else {
return cmp;
}
}
}
/** Starting from node, find the top N min cost (Long
* output) completions to a final node.
/** Starting from node, find the top N min cost
* completions to a final node.
*
* <p>NOTE: you must share the outputs when you build the
* FST (pass doShare=true to {@link
* PositiveIntOutputs#getSingleton}). */
public static MinResult[] shortestPaths(FST<Long> fst, FST.Arc<Long> fromNode, int topN) throws IOException {
return new TopNSearcher(fst, fromNode, topN).search();
public static <T> MinResult<T>[] shortestPaths(FST<T> fst, FST.Arc<T> fromNode, Comparator<T> comparator, int topN) throws IOException {
return new TopNSearcher<T>(fst, fromNode, topN, comparator).search();
}
/**
@ -639,7 +635,7 @@ public final class Util {
while (!thisLevelQueue.isEmpty()) {
final FST.Arc<T> arc = thisLevelQueue.remove(thisLevelQueue.size() - 1);
//System.out.println(" pop: " + arc);
if (fst.targetHasArcs(arc)) {
if (FST.targetHasArcs(arc)) {
// scan all target arcs
//System.out.println(" readFirstTarget...");
final int node = arc.target;
@ -694,7 +690,7 @@ public final class Util {
outs = "";
}
if (!fst.targetHasArcs(arc) && arc.isFinal() && arc.nextFinalOutput != NO_OUTPUT) {
if (!FST.targetHasArcs(arc) && arc.isFinal() && arc.nextFinalOutput != NO_OUTPUT) {
// Tricky special case: sometimes, due to
// pruning, the builder can [sillily] produce
// an FST with an arc into the final end state

View File

@ -59,6 +59,7 @@ import org.apache.lucene.util.UnicodeUtil;
import org.apache.lucene.util._TestUtil;
import org.apache.lucene.util.fst.FST.Arc;
import org.apache.lucene.util.fst.FST.BytesReader;
import org.apache.lucene.util.fst.PairOutputs.Pair;
@UseNoMemoryExpensiveCodec
public class TestFSTs extends LuceneTestCase {
@ -1975,6 +1976,12 @@ public class TestFSTs extends LuceneTestCase {
assertFalse(arc.isFinal());
assertEquals(42, arc.output.longValue());
}
static final Comparator<Long> minLongComparator = new Comparator<Long> () {
public int compare(Long left, Long right) {
return left.compareTo(right);
}
};
public void testShortestPaths() throws Exception {
final PositiveIntOutputs outputs = PositiveIntOutputs.getSingleton(true);
@ -1989,19 +1996,65 @@ public class TestFSTs extends LuceneTestCase {
//Util.toDot(fst, w, false, false);
//w.close();
Util.MinResult[] r = Util.shortestPaths(fst,
Util.MinResult<Long>[] r = Util.shortestPaths(fst,
fst.getFirstArc(new FST.Arc<Long>()),
minLongComparator,
3);
assertEquals(3, r.length);
assertEquals(Util.toIntsRef(new BytesRef("aac"), scratch), r[0].input);
assertEquals(7, r[0].output);
assertEquals(7L, r[0].output.longValue());
assertEquals(Util.toIntsRef(new BytesRef("ax"), scratch), r[1].input);
assertEquals(17, r[1].output);
assertEquals(17L, r[1].output.longValue());
assertEquals(Util.toIntsRef(new BytesRef("aab"), scratch), r[2].input);
assertEquals(22, r[2].output);
assertEquals(22L, r[2].output.longValue());
}
// compares just the weight side of the pair
static final Comparator<Pair<Long,Long>> minPairWeightComparator = new Comparator<Pair<Long,Long>> () {
public int compare(Pair<Long,Long> left, Pair<Long,Long> right) {
return left.output1.compareTo(right.output1);
}
};
/** like testShortestPaths, but uses pairoutputs so we have both a weight and an output */
public void testShortestPathsWFST() throws Exception {
PairOutputs<Long,Long> outputs = new PairOutputs<Long,Long>(
PositiveIntOutputs.getSingleton(true), // weight
PositiveIntOutputs.getSingleton(true) // output
);
final Builder<Pair<Long,Long>> builder = new Builder<Pair<Long,Long>>(FST.INPUT_TYPE.BYTE1, outputs);
final IntsRef scratch = new IntsRef();
builder.add(Util.toIntsRef(new BytesRef("aab"), scratch), outputs.newPair(22L, 57L));
builder.add(Util.toIntsRef(new BytesRef("aac"), scratch), outputs.newPair(7L, 36L));
builder.add(Util.toIntsRef(new BytesRef("ax"), scratch), outputs.newPair(17L, 85L));
final FST<Pair<Long,Long>> fst = builder.finish();
//Writer w = new OutputStreamWriter(new FileOutputStream("out.dot"));
//Util.toDot(fst, w, false, false);
//w.close();
Util.MinResult<Pair<Long,Long>>[] r = Util.shortestPaths(fst,
fst.getFirstArc(new FST.Arc<Pair<Long,Long>>()),
minPairWeightComparator,
3);
assertEquals(3, r.length);
assertEquals(Util.toIntsRef(new BytesRef("aac"), scratch), r[0].input);
assertEquals(7L, r[0].output.output1.longValue()); // weight
assertEquals(36L, r[0].output.output2.longValue()); // output
assertEquals(Util.toIntsRef(new BytesRef("ax"), scratch), r[1].input);
assertEquals(17L, r[1].output.output1.longValue()); // weight
assertEquals(85L, r[1].output.output2.longValue()); // output
assertEquals(Util.toIntsRef(new BytesRef("aab"), scratch), r[2].input);
assertEquals(22L, r[2].output.output1.longValue()); // weight
assertEquals(57L, r[2].output.output2.longValue()); // output
}
public void testShortestPathsRandom() throws Exception {
@ -2059,17 +2112,121 @@ public class TestFSTs extends LuceneTestCase {
final int topN = _TestUtil.nextInt(random, 1, 10);
Util.MinResult[] r = Util.shortestPaths(fst, arc, topN);
Util.MinResult<Long>[] r = Util.shortestPaths(fst, arc, minLongComparator, topN);
// 2. go thru whole treemap (slowCompletor) and check its actually the best suggestion
final List<Util.MinResult> matches = new ArrayList<Util.MinResult>();
final List<Util.MinResult<Long>> matches = new ArrayList<Util.MinResult<Long>>();
// TODO: could be faster... but its slowCompletor for a reason
for (Map.Entry<String,Long> e : slowCompletor.entrySet()) {
if (e.getKey().startsWith(prefix)) {
//System.out.println(" consider " + e.getKey());
matches.add(new Util.MinResult(Util.toIntsRef(new BytesRef(e.getKey().substring(prefix.length())), new IntsRef()),
e.getValue() - prefixOutput));
matches.add(new Util.MinResult<Long>(Util.toIntsRef(new BytesRef(e.getKey().substring(prefix.length())), new IntsRef()),
e.getValue() - prefixOutput, minLongComparator));
}
}
assertTrue(matches.size() > 0);
Collections.sort(matches);
if (matches.size() > topN) {
matches.subList(topN, matches.size()).clear();
}
assertEquals(matches.size(), r.length);
for(int hit=0;hit<r.length;hit++) {
//System.out.println(" check hit " + hit);
assertEquals(matches.get(hit).input, r[hit].input);
assertEquals(matches.get(hit).output, r[hit].output);
}
}
}
// used by slowcompletor
class TwoLongs {
long a;
long b;
TwoLongs(long a, long b) {
this.a = a;
this.b = b;
}
}
/** like testShortestPathsRandom, but uses pairoutputs so we have both a weight and an output */
public void testShortestPathsWFSTRandom() throws Exception {
int numWords = atLeast(1000);
final TreeMap<String,TwoLongs> slowCompletor = new TreeMap<String,TwoLongs>();
final TreeSet<String> allPrefixes = new TreeSet<String>();
PairOutputs<Long,Long> outputs = new PairOutputs<Long,Long>(
PositiveIntOutputs.getSingleton(true), // weight
PositiveIntOutputs.getSingleton(true) // output
);
final Builder<Pair<Long,Long>> builder = new Builder<Pair<Long,Long>>(FST.INPUT_TYPE.BYTE1, outputs);
final IntsRef scratch = new IntsRef();
for (int i = 0; i < numWords; i++) {
String s;
while (true) {
s = _TestUtil.randomSimpleString(random);
if (!slowCompletor.containsKey(s)) {
break;
}
}
for (int j = 1; j < s.length(); j++) {
allPrefixes.add(s.substring(0, j));
}
int weight = _TestUtil.nextInt(random, 1, 100); // weights 1..100
int output = _TestUtil.nextInt(random, 0, 500); // outputs 0..500
slowCompletor.put(s, new TwoLongs(weight, output));
}
for (Map.Entry<String,TwoLongs> e : slowCompletor.entrySet()) {
//System.out.println("add: " + e);
long weight = e.getValue().a;
long output = e.getValue().b;
builder.add(Util.toIntsRef(new BytesRef(e.getKey()), scratch), outputs.newPair(weight, output));
}
final FST<Pair<Long,Long>> fst = builder.finish();
//System.out.println("SAVE out.dot");
//Writer w = new OutputStreamWriter(new FileOutputStream("out.dot"));
//Util.toDot(fst, w, false, false);
//w.close();
BytesReader reader = fst.getBytesReader(0);
//System.out.println("testing: " + allPrefixes.size() + " prefixes");
for (String prefix : allPrefixes) {
// 1. run prefix against fst, then complete by value
//System.out.println("TEST: " + prefix);
Pair<Long,Long> prefixOutput = outputs.getNoOutput();
FST.Arc<Pair<Long,Long>> arc = fst.getFirstArc(new FST.Arc<Pair<Long,Long>>());
for(int idx=0;idx<prefix.length();idx++) {
if (fst.findTargetArc((int) prefix.charAt(idx), arc, arc, reader) == null) {
fail();
}
prefixOutput = outputs.add(prefixOutput, arc.output);
}
final int topN = _TestUtil.nextInt(random, 1, 10);
Util.MinResult<Pair<Long,Long>>[] r = Util.shortestPaths(fst, arc, minPairWeightComparator, topN);
// 2. go thru whole treemap (slowCompletor) and check its actually the best suggestion
final List<Util.MinResult<Pair<Long,Long>>> matches = new ArrayList<Util.MinResult<Pair<Long,Long>>>();
// TODO: could be faster... but its slowCompletor for a reason
for (Map.Entry<String,TwoLongs> e : slowCompletor.entrySet()) {
if (e.getKey().startsWith(prefix)) {
//System.out.println(" consider " + e.getKey());
matches.add(new Util.MinResult<Pair<Long,Long>>(Util.toIntsRef(new BytesRef(e.getKey().substring(prefix.length())), new IntsRef()),
outputs.newPair(e.getValue().a - prefixOutput.output1, e.getValue().b - prefixOutput.output2),
minPairWeightComparator));
}
}

View File

@ -23,6 +23,7 @@ import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import org.apache.lucene.search.spell.TermFreqIterator;
@ -55,7 +56,7 @@ import org.apache.lucene.util.fst.Util.MinResult;
* Input weights will be cast to a java integer, and any
* negative, infinite, or NaN values will be rejected.
*
* @see Util#shortestPaths(FST, FST.Arc, int)
* @see Util#shortestPaths(FST, FST.Arc, Comparator, int)
* @lucene.experimental
*/
public class WFSTCompletionLookup extends Lookup {
@ -230,13 +231,13 @@ public class WFSTCompletionLookup extends Lookup {
}
// complete top-N
MinResult completions[] = null;
MinResult<Long> completions[] = null;
try {
completions = Util.shortestPaths(fst, arc, num);
completions = Util.shortestPaths(fst, arc, weightComparator, num);
} catch (IOException bogus) { throw new RuntimeException(bogus); }
BytesRef suffix = new BytesRef(8);
for (MinResult completion : completions) {
for (MinResult<Long> completion : completions) {
scratch.length = prefixLength;
// append suffix
Util.toBytesRef(completion.input, suffix);
@ -304,4 +305,10 @@ public class WFSTCompletionLookup extends Lookup {
}
return Integer.MAX_VALUE - (int)value;
}
static final Comparator<Long> weightComparator = new Comparator<Long> () {
public int compare(Long left, Long right) {
return left.compareTo(right);
}
};
}