LUCENE-8920: refactor FST binary search

This commit is contained in:
Michael Sokolov 2019-07-04 10:45:17 -04:00 committed by Michael Sokolov
parent fe0c042470
commit 92d4e712d5
6 changed files with 180 additions and 94 deletions

View File

@ -81,7 +81,7 @@ public final class BytesRefFSTEnum<T> extends FSTEnum<T> {
public InputOutput<T> seekExact(BytesRef target) throws IOException { public InputOutput<T> seekExact(BytesRef target) throws IOException {
this.target = target; this.target = target;
targetLength = target.length; targetLength = target.length;
if (super.doSeekExact()) { if (doSeekExact()) {
assert upto == 1+target.length; assert upto == 1+target.length;
return setResult(); return setResult();
} else { } else {

View File

@ -207,36 +207,12 @@ abstract class FSTEnum<T> {
private FST.Arc<T> doSeekCeilArrayPacked(final FST.Arc<T> arc, final int targetLabel, final FST.BytesReader in) throws IOException { private FST.Arc<T> doSeekCeilArrayPacked(final FST.Arc<T> arc, final int targetLabel, final FST.BytesReader in) throws IOException {
// The array is packed -- use binary search to find the target. // The array is packed -- use binary search to find the target.
int idx = Util.binarySearch(fst, arc, targetLabel);
int low = arc.arcIdx(); if (idx >= 0) {
int high = arc.numArcs() -1;
int mid = 0;
//System.out.println("do arc array low=" + low + " high=" + high + " targetLabel=" + targetLabel);
boolean found = false;
while (low <= high) {
mid = (low + high) >>> 1;
in.setPosition(arc.posArcsStart());
in.skipBytes(arc.bytesPerArc() * mid + 1);
final int midLabel = fst.readLabel(in);
final int cmp = midLabel - targetLabel;
//System.out.println(" cycle low=" + low + " high=" + high + " mid=" + mid + " midLabel=" + midLabel + " cmp=" + cmp);
if (cmp < 0)
low = mid + 1;
else if (cmp > 0)
high = mid - 1;
else {
found = true;
break;
}
}
// NOTE: this code is dup'd w/ the code below (in
// the outer else clause):
if (found) {
// Match // Match
fst.readArcByIndex(arc, in, mid); fst.readArcByIndex(arc, in, idx);
assert arc.arcIdx() == mid; assert arc.arcIdx() == idx;
assert arc.label() == targetLabel: "arc.label=" + arc.label() + " vs targetLabel=" + targetLabel + " mid=" + mid; assert arc.label() == targetLabel: "arc.label=" + arc.label() + " vs targetLabel=" + targetLabel + " mid=" + idx;
output[upto] = fst.outputs.add(output[upto-1], arc.output()); output[upto] = fst.outputs.add(output[upto-1], arc.output());
if (targetLabel == FST.END_LABEL) { if (targetLabel == FST.END_LABEL) {
return null; return null;
@ -244,9 +220,11 @@ abstract class FSTEnum<T> {
setCurrentLabel(arc.label()); setCurrentLabel(arc.label());
incr(); incr();
return fst.readFirstTargetArc(arc, getArc(upto), fstReader); return fst.readFirstTargetArc(arc, getArc(upto), fstReader);
} else if (low == arc.numArcs()) { }
idx = -1 - idx;
if (idx == arc.numArcs()) {
// Dead end // Dead end
fst.readArcByIndex(arc, in, arc.numArcs() - 1); fst.readArcByIndex(arc, in, idx - 1);
assert arc.isLast(); assert arc.isLast();
// Dead end (target is after the last arc); // Dead end (target is after the last arc);
// rollback to last fork then push // rollback to last fork then push
@ -265,7 +243,8 @@ abstract class FSTEnum<T> {
upto--; upto--;
} }
} else { } else {
fst.readArcByIndex(arc, in, low); // Ceiling - arc with least higher label
fst.readArcByIndex(arc, in, idx);
assert arc.label() > targetLabel; assert arc.label() > targetLabel;
pushFirst(); pushFirst();
return null; return null;
@ -314,7 +293,7 @@ abstract class FSTEnum<T> {
// Todo: should we return a status here (SEEK_FOUND / SEEK_NOT_FOUND / // Todo: should we return a status here (SEEK_FOUND / SEEK_NOT_FOUND /
// SEEK_END)? saves the eq check above? // SEEK_END)? saves the eq check above?
/** Seeks to largest term that's &lt;= target. */ /** Seeks to largest term that's &lt;= target. */
protected void doSeekFloor() throws IOException { void doSeekFloor() throws IOException {
// TODO: possibly caller could/should provide common // TODO: possibly caller could/should provide common
// prefix length? ie this work may be redundant if // prefix length? ie this work may be redundant if
@ -417,37 +396,14 @@ abstract class FSTEnum<T> {
private FST.Arc<T> doSeekFloorArrayPacked(FST.Arc<T> arc, int targetLabel, final FST.BytesReader in) throws IOException { private FST.Arc<T> doSeekFloorArrayPacked(FST.Arc<T> arc, int targetLabel, final FST.BytesReader in) throws IOException {
// Arcs are fixed array -- use binary search to find the target. // Arcs are fixed array -- use binary search to find the target.
int idx = Util.binarySearch(fst, arc, targetLabel);
int low = arc.arcIdx(); if (idx >= 0) {
int high = arc.numArcs() -1;
int mid = 0;
//System.out.println("do arc array low=" + low + " high=" + high + " targetLabel=" + targetLabel);
boolean found = false;
while (low <= high) {
mid = (low + high) >>> 1;
in.setPosition(arc.posArcsStart());
in.skipBytes(arc.bytesPerArc() * mid + 1);
final int midLabel = fst.readLabel(in);
final int cmp = midLabel - targetLabel;
//System.out.println(" cycle low=" + low + " high=" + high + " mid=" + mid + " midLabel=" + midLabel + " cmp=" + cmp);
if (cmp < 0) {
low = mid + 1;
} else if (cmp > 0) {
high = mid - 1;
} else {
found = true;
break;
}
}
// NOTE: this code is dup'd w/ the code below (in
// the outer else clause):
if (found) {
// Match -- recurse // Match -- recurse
//System.out.println(" match! arcIdx=" + mid); //System.out.println(" match! arcIdx=" + idx);
fst.readArcByIndex(arc, in, mid); fst.readArcByIndex(arc, in, idx);
assert arc.arcIdx() == mid; assert arc.arcIdx() == idx;
assert arc.label() == targetLabel: "arc.label=" + arc.label() + " vs targetLabel=" + targetLabel + " mid=" + mid; assert arc.label() == targetLabel: "arc.label=" + arc.label() + " vs targetLabel=" + targetLabel + " mid=" + idx;
output[upto] = fst.outputs.add(output[upto-1], arc.output()); output[upto] = fst.outputs.add(output[upto-1], arc.output());
if (targetLabel == FST.END_LABEL) { if (targetLabel == FST.END_LABEL) {
return null; return null;
@ -455,7 +411,7 @@ abstract class FSTEnum<T> {
setCurrentLabel(arc.label()); setCurrentLabel(arc.label());
incr(); incr();
return fst.readFirstTargetArc(arc, getArc(upto), fstReader); return fst.readFirstTargetArc(arc, getArc(upto), fstReader);
} else if (high == -1) { } else if (idx == -1) {
//System.out.println(" before first"); //System.out.println(" before first");
// Very first arc is after our target // Very first arc is after our target
// TODO: if each arc could somehow read the arc just // TODO: if each arc could somehow read the arc just
@ -483,8 +439,8 @@ abstract class FSTEnum<T> {
arc = getArc(upto); arc = getArc(upto);
} }
} else { } else {
// There is a floor arc: // There is a floor arc; idx will be {@code -1 - (floor + 1)}.
fst.readArcByIndex(arc, in, high); fst.readArcByIndex(arc, in, -2 - idx);
assert arc.isLast() || fst.readNextArcLabel(arc, in) > targetLabel; assert arc.isLast() || fst.readNextArcLabel(arc, in) > targetLabel;
assert arc.label() < targetLabel: "arc.label=" + arc.label() + " vs targetLabel=" + targetLabel; assert arc.label() < targetLabel: "arc.label=" + arc.label() + " vs targetLabel=" + targetLabel;
pushLast(); pushLast();
@ -652,4 +608,5 @@ abstract class FSTEnum<T> {
} }
return arcs[idx]; return arcs[idx];
} }
} }

View File

@ -81,7 +81,7 @@ public final class IntsRefFSTEnum<T> extends FSTEnum<T> {
public InputOutput<T> seekExact(IntsRef target) throws IOException { public InputOutput<T> seekExact(IntsRef target) throws IOException {
this.target = target; this.target = target;
targetLength = target.length; targetLength = target.length;
if (super.doSeekExact()) { if (doSeekExact()) {
assert upto == 1+target.length; assert upto == 1+target.length;
return setResult(); return setResult();
} else { } else {

View File

@ -951,35 +951,17 @@ public final class Util {
return fst.readArcAtPosition(arc, in, arc.posArcsStart() - offset * arc.bytesPerArc()); return fst.readArcAtPosition(arc, in, arc.posArcsStart() - offset * arc.bytesPerArc());
} }
} }
// Arcs are packed array -- use binary search to find // Arcs are packed array -- use binary search to find the target.
// the target. int idx = binarySearch(fst, arc, label);
if (idx >= 0) {
int low = arc.arcIdx(); return fst.readArcByIndex(arc, in, idx);
int mid = 0;
int high = arc.numArcs() - 1;
// System.out.println("do arc array low=" + low + " high=" + high +
// " targetLabel=" + targetLabel);
while (low <= high) {
mid = (low + high) >>> 1;
in.setPosition(arc.posArcsStart());
in.skipBytes(arc.bytesPerArc() * mid + 1);
final int midLabel = fst.readLabel(in);
final int cmp = midLabel - label;
// System.out.println(" cycle low=" + low + " high=" + high + " mid=" +
// mid + " midLabel=" + midLabel + " cmp=" + cmp);
if (cmp < 0) {
low = mid + 1;
} else if (cmp > 0) {
high = mid - 1;
} else {
return fst.readArcByIndex(arc, in, mid);
} }
} idx = -1 - idx;
if (low == arc.numArcs()) { if (idx == arc.numArcs()) {
// DEAD END! // DEAD END!
return null; return null;
} }
return fst.readArcByIndex(arc, in , high + 1); return fst.readArcByIndex(arc, in , idx);
} }
// Linear scan // Linear scan
@ -1001,4 +983,36 @@ public final class Util {
} }
} }
/**
 * Perform a binary search of Arcs encoded as a packed (fixed-width) array.
 *
 * <p>The return value follows the same convention as {@link java.util.Arrays#binarySearch(int[], int)}:
 * a non-negative value is the index of the match, a negative value encodes the insertion point.
 * Callers recover the insertion point with {@code -1 - result}.
 *
 * @param fst the FST from which to read
 * @param arc the starting arc; sibling arcs greater than this will be searched. Usually the first arc in the array.
 * @param targetLabel the label to search for
 * @param <T> the output type of the FST
 * @return the index of the Arc having the target label, or if no Arc has the matching label, {@code -1 - idx},
 *         where {@code idx} is the index of the Arc with the next highest label, or the total number of arcs
 *         if the target label exceeds the maximum.
 * @throws IOException when the FST reader does
 */
static <T> int binarySearch(FST<T> fst, FST.Arc<T> arc, int targetLabel) throws IOException {
  // A private reader is obtained here, so the search does not move any reader held by the caller.
  BytesReader in = fst.getBytesReader();
  int low = arc.arcIdx();
  int high = arc.numArcs() - 1;
  while (low <= high) {
    // Unsigned shift keeps the midpoint correct even if (low + high) were to overflow.
    final int mid = (low + high) >>> 1;
    // Seek to the mid-th fixed-width entry and read its label.
    // NOTE(review): the extra "+ 1" presumably skips a one-byte header (flags) that precedes
    // the label in each arc entry -- confirm against the arc encoding in FST.
    in.setPosition(arc.posArcsStart());
    in.skipBytes(arc.bytesPerArc() * mid + 1);
    final int midLabel = fst.readLabel(in);
    // Integer.compare instead of subtraction: the difference of two ints can overflow and
    // flip sign (e.g. a negative target against a very large label).
    final int cmp = Integer.compare(midLabel, targetLabel);
    if (cmp < 0) {
      low = mid + 1;
    } else if (cmp > 0) {
      high = mid - 1;
    } else {
      return mid;
    }
  }
  // Not found: low is now the index of the first arc whose label is greater than targetLabel
  // (or numArcs when there is none).
  return -1 - low;
}
} }

View File

@ -0,0 +1,115 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.fst;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IntsRefBuilder;
import org.apache.lucene.util.LuceneTestCase;
/**
 * Tests for the arc lookup helpers in {@link Util}: {@code binarySearch} and {@code readCeilArc}.
 */
public class TestUtil extends LuceneTestCase {

  /**
   * Verifies that {@code Util.binarySearch} returns the position of every label that is present,
   * and {@code -1 - insertionPoint} for labels that are absent.
   */
  public void testBinarySearch() throws Exception {
    // Creates a node with 8 arcs spanning (z-A) = 57 chars that will be encoded as a sparse array (no gaps)
    // requiring binary search
    final List<String> labels = Arrays.asList("A", "E", "J", "K", "L", "O", "T", "z");
    final FST<Object> fst = buildFST(labels, true);
    FST.Arc<Object> arc = fst.getFirstArc(new FST.Arc<>());
    arc = fst.readFirstTargetArc(arc, arc, fst.getBytesReader());

    // every label that exists is found at its own position in the sorted list
    int expectedIndex = 0;
    for (String label : labels) {
      assertEquals(expectedIndex, Util.binarySearch(fst, arc, label.charAt(0)));
      expectedIndex++;
    }

    // before the first: insertion point 0
    assertEquals(-1, Util.binarySearch(fst, arc, ' '));
    // after the last: insertion point == number of arcs
    assertEquals(-1 - labels.size(), Util.binarySearch(fst, arc, '~'));
    // between 'A' and 'E': insertion point 1
    assertEquals(-2, Util.binarySearch(fst, arc, 'B'));
    assertEquals(-2, Util.binarySearch(fst, arc, 'C'));
    // between 'O' and 'T': insertion point 6
    assertEquals(-7, Util.binarySearch(fst, arc, 'P'));
  }

  /** Runs the ceiling-arc checks on the 8-letter set ("A".."z") with array-encoded arcs allowed. */
  public void testReadCeilArcPackedArray() throws Exception {
    verifyReadCeilArc(Arrays.asList("A", "E", "J", "K", "L", "O", "T", "z"), true);
  }

  /** Runs the ceiling-arc checks on the 7-letter set ("A".."T") with array-encoded arcs allowed. */
  public void testReadCeilArcArrayWithGaps() throws Exception {
    verifyReadCeilArc(Arrays.asList("A", "E", "J", "K", "L", "O", "T"), true);
  }

  /** Runs the ceiling-arc checks on the 8-letter set ("A".."z") with array-encoded arcs disallowed. */
  public void testReadCeilArcList() throws Exception {
    verifyReadCeilArc(Arrays.asList("A", "E", "J", "K", "L", "O", "T", "z"), false);
  }

  /**
   * Builds an FST over {@code letters} and checks {@code Util.readCeilArc} from the root for:
   * every present label, a label below the first, a label above the last, a label in a gap,
   * and finally a lookup that follows an arc returned by a previous lookup.
   *
   * @param letters single-character words, in sorted order
   * @param allowArrayArcs passed through to the FST builder
   */
  private void verifyReadCeilArc(List<String> letters, boolean allowArrayArcs) throws Exception {
    final FST<Object> fst = buildFST(letters, allowArrayArcs);
    final FST.Arc<Object> root = fst.getFirstArc(new FST.Arc<>());
    // scratch arc that is handed to, and replaced by the result of, each lookup
    FST.Arc<Object> scratch = new FST.Arc<>();
    final FST.BytesReader reader = fst.getBytesReader();

    // exact matches: the ceiling of a present label is that label itself
    for (int i = 0; i < letters.size(); i++) {
      final char label = letters.get(i).charAt(0);
      scratch = Util.readCeilArc(label, fst, root, scratch, reader);
      assertNotNull(scratch);
      assertEquals(label, scratch.label());
    }

    // before the first
    final FST.Arc<Object> beforeFirst = Util.readCeilArc(' ', fst, root, scratch, reader);
    assertEquals('A', beforeFirst.label());

    // after the last
    final FST.Arc<Object> afterLast = Util.readCeilArc('~', fst, root, scratch, reader);
    assertNull(afterLast);

    // in the middle
    final FST.Arc<Object> inGap = Util.readCeilArc('F', fst, root, scratch, reader);
    assertEquals('J', inGap.label());

    // no following arcs: the arc being followed is the scratch arc itself
    assertNull(Util.readCeilArc('Z', fst, scratch, scratch, reader));
  }

  /**
   * Builds a BYTE1 FST with no outputs containing {@code words}, added in list order.
   *
   * @param words the inputs to add
   * @param allowArrayArcs forwarded to the {@link Builder} constructor
   * @return the finished FST
   */
  private FST<Object> buildFST(List<String> words, boolean allowArrayArcs) throws Exception {
    final Outputs<Object> outputs = NoOutputs.getSingleton();
    final Builder<Object> b =
        new Builder<>(FST.INPUT_TYPE.BYTE1, 0, 0, true, true, Integer.MAX_VALUE, outputs, allowArrayArcs, 15);
    for (int i = 0; i < words.size(); i++) {
      final BytesRef wordBytes = new BytesRef(words.get(i));
      b.add(Util.toIntsRef(wordBytes, new IntsRefBuilder()), outputs.getNoOutput());
    }
    return b.finish();
  }

  /**
   * Convenience overload that starts the recursive generator with an empty list and an empty prefix.
   * Not referenced by any test in this class.
   */
  private List<String> createRandomDictionary(int width, int depth) {
    final List<String> dict = new ArrayList<>();
    return createRandomDictionary(dict, new StringBuilder(), width, depth);
  }

  /**
   * Recursively appends random strings to {@code dict}. At each level a random start character below 128
   * is chosen and {@code width} characters are tried in turn, each one advanced from the previous by a
   * random step in [0, 8); a step of 0 repeats the character. Strings are emitted when {@code depth}
   * reaches 0, so every emitted string is {@code depth + 1} characters longer than the incoming prefix.
   *
   * @param dict receives the generated strings
   * @param buf the current prefix; restored to its incoming length before returning
   * @param width number of characters tried per level
   * @param depth number of further levels to recurse
   * @return {@code dict}
   */
  private List<String> createRandomDictionary(List<String> dict, StringBuilder buf, int width, int depth) {
    char c = (char) random().nextInt(128);
    assert width < Character.MIN_SURROGATE / 8 - 128; // avoid surrogate chars
    final int prefixLength = buf.length();
    for (int remaining = width; remaining > 0; remaining--) {
      buf.append(c);
      if (depth <= 0) {
        dict.add(buf.toString());
      } else {
        createRandomDictionary(dict, buf, width, depth - 1);
      }
      c = (char) (c + random().nextInt(8));
      // drop the character appended above so the next iteration starts from the same prefix
      buf.setLength(prefixLength);
    }
    return dict;
  }
}