mirror of https://github.com/apache/lucene.git
LUCENE-8920: refactor FST binary search
This commit is contained in:
parent
fe0c042470
commit
92d4e712d5
|
@ -81,7 +81,7 @@ public final class BytesRefFSTEnum<T> extends FSTEnum<T> {
|
||||||
public InputOutput<T> seekExact(BytesRef target) throws IOException {
|
public InputOutput<T> seekExact(BytesRef target) throws IOException {
|
||||||
this.target = target;
|
this.target = target;
|
||||||
targetLength = target.length;
|
targetLength = target.length;
|
||||||
if (super.doSeekExact()) {
|
if (doSeekExact()) {
|
||||||
assert upto == 1+target.length;
|
assert upto == 1+target.length;
|
||||||
return setResult();
|
return setResult();
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -207,36 +207,12 @@ abstract class FSTEnum<T> {
|
||||||
|
|
||||||
private FST.Arc<T> doSeekCeilArrayPacked(final FST.Arc<T> arc, final int targetLabel, final FST.BytesReader in) throws IOException {
|
private FST.Arc<T> doSeekCeilArrayPacked(final FST.Arc<T> arc, final int targetLabel, final FST.BytesReader in) throws IOException {
|
||||||
// The array is packed -- use binary search to find the target.
|
// The array is packed -- use binary search to find the target.
|
||||||
|
int idx = Util.binarySearch(fst, arc, targetLabel);
|
||||||
int low = arc.arcIdx();
|
if (idx >= 0) {
|
||||||
int high = arc.numArcs() -1;
|
|
||||||
int mid = 0;
|
|
||||||
//System.out.println("do arc array low=" + low + " high=" + high + " targetLabel=" + targetLabel);
|
|
||||||
boolean found = false;
|
|
||||||
while (low <= high) {
|
|
||||||
mid = (low + high) >>> 1;
|
|
||||||
in.setPosition(arc.posArcsStart());
|
|
||||||
in.skipBytes(arc.bytesPerArc() * mid + 1);
|
|
||||||
final int midLabel = fst.readLabel(in);
|
|
||||||
final int cmp = midLabel - targetLabel;
|
|
||||||
//System.out.println(" cycle low=" + low + " high=" + high + " mid=" + mid + " midLabel=" + midLabel + " cmp=" + cmp);
|
|
||||||
if (cmp < 0)
|
|
||||||
low = mid + 1;
|
|
||||||
else if (cmp > 0)
|
|
||||||
high = mid - 1;
|
|
||||||
else {
|
|
||||||
found = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NOTE: this code is dup'd w/ the code below (in
|
|
||||||
// the outer else clause):
|
|
||||||
if (found) {
|
|
||||||
// Match
|
// Match
|
||||||
fst.readArcByIndex(arc, in, mid);
|
fst.readArcByIndex(arc, in, idx);
|
||||||
assert arc.arcIdx() == mid;
|
assert arc.arcIdx() == idx;
|
||||||
assert arc.label() == targetLabel: "arc.label=" + arc.label() + " vs targetLabel=" + targetLabel + " mid=" + mid;
|
assert arc.label() == targetLabel: "arc.label=" + arc.label() + " vs targetLabel=" + targetLabel + " mid=" + idx;
|
||||||
output[upto] = fst.outputs.add(output[upto-1], arc.output());
|
output[upto] = fst.outputs.add(output[upto-1], arc.output());
|
||||||
if (targetLabel == FST.END_LABEL) {
|
if (targetLabel == FST.END_LABEL) {
|
||||||
return null;
|
return null;
|
||||||
|
@ -244,9 +220,11 @@ abstract class FSTEnum<T> {
|
||||||
setCurrentLabel(arc.label());
|
setCurrentLabel(arc.label());
|
||||||
incr();
|
incr();
|
||||||
return fst.readFirstTargetArc(arc, getArc(upto), fstReader);
|
return fst.readFirstTargetArc(arc, getArc(upto), fstReader);
|
||||||
} else if (low == arc.numArcs()) {
|
}
|
||||||
|
idx = -1 - idx;
|
||||||
|
if (idx == arc.numArcs()) {
|
||||||
// Dead end
|
// Dead end
|
||||||
fst.readArcByIndex(arc, in, arc.numArcs() - 1);
|
fst.readArcByIndex(arc, in, idx - 1);
|
||||||
assert arc.isLast();
|
assert arc.isLast();
|
||||||
// Dead end (target is after the last arc);
|
// Dead end (target is after the last arc);
|
||||||
// rollback to last fork then push
|
// rollback to last fork then push
|
||||||
|
@ -265,7 +243,8 @@ abstract class FSTEnum<T> {
|
||||||
upto--;
|
upto--;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
fst.readArcByIndex(arc, in, low);
|
// Ceiling - arc with least higher label
|
||||||
|
fst.readArcByIndex(arc, in, idx);
|
||||||
assert arc.label() > targetLabel;
|
assert arc.label() > targetLabel;
|
||||||
pushFirst();
|
pushFirst();
|
||||||
return null;
|
return null;
|
||||||
|
@ -314,7 +293,7 @@ abstract class FSTEnum<T> {
|
||||||
// Todo: should we return a status here (SEEK_FOUND / SEEK_NOT_FOUND /
|
// Todo: should we return a status here (SEEK_FOUND / SEEK_NOT_FOUND /
|
||||||
// SEEK_END)? saves the eq check above?
|
// SEEK_END)? saves the eq check above?
|
||||||
/** Seeks to largest term that's <= target. */
|
/** Seeks to largest term that's <= target. */
|
||||||
protected void doSeekFloor() throws IOException {
|
void doSeekFloor() throws IOException {
|
||||||
|
|
||||||
// TODO: possibly caller could/should provide common
|
// TODO: possibly caller could/should provide common
|
||||||
// prefix length? ie this work may be redundant if
|
// prefix length? ie this work may be redundant if
|
||||||
|
@ -417,37 +396,14 @@ abstract class FSTEnum<T> {
|
||||||
|
|
||||||
private FST.Arc<T> doSeekFloorArrayPacked(FST.Arc<T> arc, int targetLabel, final FST.BytesReader in) throws IOException {
|
private FST.Arc<T> doSeekFloorArrayPacked(FST.Arc<T> arc, int targetLabel, final FST.BytesReader in) throws IOException {
|
||||||
// Arcs are fixed array -- use binary search to find the target.
|
// Arcs are fixed array -- use binary search to find the target.
|
||||||
|
int idx = Util.binarySearch(fst, arc, targetLabel);
|
||||||
|
|
||||||
int low = arc.arcIdx();
|
if (idx >= 0) {
|
||||||
int high = arc.numArcs() -1;
|
|
||||||
int mid = 0;
|
|
||||||
//System.out.println("do arc array low=" + low + " high=" + high + " targetLabel=" + targetLabel);
|
|
||||||
boolean found = false;
|
|
||||||
while (low <= high) {
|
|
||||||
mid = (low + high) >>> 1;
|
|
||||||
in.setPosition(arc.posArcsStart());
|
|
||||||
in.skipBytes(arc.bytesPerArc() * mid + 1);
|
|
||||||
final int midLabel = fst.readLabel(in);
|
|
||||||
final int cmp = midLabel - targetLabel;
|
|
||||||
//System.out.println(" cycle low=" + low + " high=" + high + " mid=" + mid + " midLabel=" + midLabel + " cmp=" + cmp);
|
|
||||||
if (cmp < 0) {
|
|
||||||
low = mid + 1;
|
|
||||||
} else if (cmp > 0) {
|
|
||||||
high = mid - 1;
|
|
||||||
} else {
|
|
||||||
found = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NOTE: this code is dup'd w/ the code below (in
|
|
||||||
// the outer else clause):
|
|
||||||
if (found) {
|
|
||||||
// Match -- recurse
|
// Match -- recurse
|
||||||
//System.out.println(" match! arcIdx=" + mid);
|
//System.out.println(" match! arcIdx=" + idx);
|
||||||
fst.readArcByIndex(arc, in, mid);
|
fst.readArcByIndex(arc, in, idx);
|
||||||
assert arc.arcIdx() == mid;
|
assert arc.arcIdx() == idx;
|
||||||
assert arc.label() == targetLabel: "arc.label=" + arc.label() + " vs targetLabel=" + targetLabel + " mid=" + mid;
|
assert arc.label() == targetLabel: "arc.label=" + arc.label() + " vs targetLabel=" + targetLabel + " mid=" + idx;
|
||||||
output[upto] = fst.outputs.add(output[upto-1], arc.output());
|
output[upto] = fst.outputs.add(output[upto-1], arc.output());
|
||||||
if (targetLabel == FST.END_LABEL) {
|
if (targetLabel == FST.END_LABEL) {
|
||||||
return null;
|
return null;
|
||||||
|
@ -455,7 +411,7 @@ abstract class FSTEnum<T> {
|
||||||
setCurrentLabel(arc.label());
|
setCurrentLabel(arc.label());
|
||||||
incr();
|
incr();
|
||||||
return fst.readFirstTargetArc(arc, getArc(upto), fstReader);
|
return fst.readFirstTargetArc(arc, getArc(upto), fstReader);
|
||||||
} else if (high == -1) {
|
} else if (idx == -1) {
|
||||||
//System.out.println(" before first");
|
//System.out.println(" before first");
|
||||||
// Very first arc is after our target
|
// Very first arc is after our target
|
||||||
// TODO: if each arc could somehow read the arc just
|
// TODO: if each arc could somehow read the arc just
|
||||||
|
@ -483,8 +439,8 @@ abstract class FSTEnum<T> {
|
||||||
arc = getArc(upto);
|
arc = getArc(upto);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// There is a floor arc:
|
// There is a floor arc; idx will be {@code -1 - (floor + 1)}.
|
||||||
fst.readArcByIndex(arc, in, high);
|
fst.readArcByIndex(arc, in, -2 - idx);
|
||||||
assert arc.isLast() || fst.readNextArcLabel(arc, in) > targetLabel;
|
assert arc.isLast() || fst.readNextArcLabel(arc, in) > targetLabel;
|
||||||
assert arc.label() < targetLabel: "arc.label=" + arc.label() + " vs targetLabel=" + targetLabel;
|
assert arc.label() < targetLabel: "arc.label=" + arc.label() + " vs targetLabel=" + targetLabel;
|
||||||
pushLast();
|
pushLast();
|
||||||
|
@ -652,4 +608,5 @@ abstract class FSTEnum<T> {
|
||||||
}
|
}
|
||||||
return arcs[idx];
|
return arcs[idx];
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -81,7 +81,7 @@ public final class IntsRefFSTEnum<T> extends FSTEnum<T> {
|
||||||
public InputOutput<T> seekExact(IntsRef target) throws IOException {
|
public InputOutput<T> seekExact(IntsRef target) throws IOException {
|
||||||
this.target = target;
|
this.target = target;
|
||||||
targetLength = target.length;
|
targetLength = target.length;
|
||||||
if (super.doSeekExact()) {
|
if (doSeekExact()) {
|
||||||
assert upto == 1+target.length;
|
assert upto == 1+target.length;
|
||||||
return setResult();
|
return setResult();
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -951,35 +951,17 @@ public final class Util {
|
||||||
return fst.readArcAtPosition(arc, in, arc.posArcsStart() - offset * arc.bytesPerArc());
|
return fst.readArcAtPosition(arc, in, arc.posArcsStart() - offset * arc.bytesPerArc());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Arcs are packed array -- use binary search to find
|
// Arcs are packed array -- use binary search to find the target.
|
||||||
// the target.
|
int idx = binarySearch(fst, arc, label);
|
||||||
|
if (idx >= 0) {
|
||||||
int low = arc.arcIdx();
|
return fst.readArcByIndex(arc, in, idx);
|
||||||
int mid = 0;
|
|
||||||
int high = arc.numArcs() - 1;
|
|
||||||
// System.out.println("do arc array low=" + low + " high=" + high +
|
|
||||||
// " targetLabel=" + targetLabel);
|
|
||||||
while (low <= high) {
|
|
||||||
mid = (low + high) >>> 1;
|
|
||||||
in.setPosition(arc.posArcsStart());
|
|
||||||
in.skipBytes(arc.bytesPerArc() * mid + 1);
|
|
||||||
final int midLabel = fst.readLabel(in);
|
|
||||||
final int cmp = midLabel - label;
|
|
||||||
// System.out.println(" cycle low=" + low + " high=" + high + " mid=" +
|
|
||||||
// mid + " midLabel=" + midLabel + " cmp=" + cmp);
|
|
||||||
if (cmp < 0) {
|
|
||||||
low = mid + 1;
|
|
||||||
} else if (cmp > 0) {
|
|
||||||
high = mid - 1;
|
|
||||||
} else {
|
|
||||||
return fst.readArcByIndex(arc, in, mid);
|
|
||||||
}
|
}
|
||||||
}
|
idx = -1 - idx;
|
||||||
if (low == arc.numArcs()) {
|
if (idx == arc.numArcs()) {
|
||||||
// DEAD END!
|
// DEAD END!
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
return fst.readArcByIndex(arc, in , high + 1);
|
return fst.readArcByIndex(arc, in , idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Linear scan
|
// Linear scan
|
||||||
|
@ -1001,4 +983,36 @@ public final class Util {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Perform a binary search of Arcs encoded as a packed array
|
||||||
|
* @param fst the FST from which to read
|
||||||
|
* @param arc the starting arc; sibling arcs greater than this will be searched. Usually the first arc in the array.
|
||||||
|
* @param targetLabel the label to search for
|
||||||
|
* @param <T> the output type of the FST
|
||||||
|
* @return the index of the Arc having the target label, or if no Arc has the matching label, {@code -1 - idx)},
|
||||||
|
* where {@code idx} is the index of the Arc with the next highest label, or the total number of arcs
|
||||||
|
* if the target label exceeds the maximum.
|
||||||
|
* @throws IOException when the FST reader does
|
||||||
|
*/
|
||||||
|
static <T> int binarySearch(FST<T> fst, FST.Arc<T> arc, int targetLabel) throws IOException {
|
||||||
|
BytesReader in = fst.getBytesReader();
|
||||||
|
int low = arc.arcIdx();
|
||||||
|
int mid = 0;
|
||||||
|
int high = arc.numArcs() -1;
|
||||||
|
while (low <= high) {
|
||||||
|
mid = (low + high) >>> 1;
|
||||||
|
in.setPosition(arc.posArcsStart());
|
||||||
|
in.skipBytes(arc.bytesPerArc() * mid + 1);
|
||||||
|
final int midLabel = fst.readLabel(in);
|
||||||
|
final int cmp = midLabel - targetLabel;
|
||||||
|
if (cmp < 0) {
|
||||||
|
low = mid + 1;
|
||||||
|
} else if (cmp > 0) {
|
||||||
|
high = mid - 1;
|
||||||
|
} else {
|
||||||
|
return mid;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1 - low;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,115 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.lucene.util.fst;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import org.apache.lucene.util.BytesRef;
|
||||||
|
import org.apache.lucene.util.IntsRefBuilder;
|
||||||
|
import org.apache.lucene.util.LuceneTestCase;
|
||||||
|
|
||||||
|
public class TestUtil extends LuceneTestCase {
|
||||||
|
|
||||||
|
public void testBinarySearch() throws Exception {
|
||||||
|
// Creates a node with 8 arcs spanning (z-A) = 57 chars that will be encoded as a sparse array (no gaps)
|
||||||
|
// requiring binary search
|
||||||
|
List<String> letters = Arrays.asList("A", "E", "J", "K", "L", "O", "T", "z");
|
||||||
|
FST<Object> fst = buildFST(letters, true);
|
||||||
|
FST.Arc<Object> arc = fst.getFirstArc(new FST.Arc<>());
|
||||||
|
arc = fst.readFirstTargetArc(arc, arc, fst.getBytesReader());
|
||||||
|
for (int i = 0; i < letters.size(); i++) {
|
||||||
|
assertEquals(i, Util.binarySearch(fst, arc, letters.get(i).charAt(0)));
|
||||||
|
}
|
||||||
|
// before the first
|
||||||
|
assertEquals(-1, Util.binarySearch(fst, arc, ' '));
|
||||||
|
// after the last
|
||||||
|
assertEquals(-1 - letters.size(), Util.binarySearch(fst, arc, '~'));
|
||||||
|
assertEquals(-2, Util.binarySearch(fst, arc, 'B'));
|
||||||
|
assertEquals(-2, Util.binarySearch(fst, arc, 'C'));
|
||||||
|
assertEquals(-7, Util.binarySearch(fst, arc, 'P'));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testReadCeilArcPackedArray() throws Exception {
|
||||||
|
List<String> letters = Arrays.asList("A", "E", "J", "K", "L", "O", "T", "z");
|
||||||
|
verifyReadCeilArc(letters, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testReadCeilArcArrayWithGaps() throws Exception {
|
||||||
|
List<String> letters = Arrays.asList("A", "E", "J", "K", "L", "O", "T");
|
||||||
|
verifyReadCeilArc(letters, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testReadCeilArcList() throws Exception {
|
||||||
|
List<String> letters = Arrays.asList("A", "E", "J", "K", "L", "O", "T", "z");
|
||||||
|
verifyReadCeilArc(letters, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void verifyReadCeilArc(List<String> letters, boolean allowArrayArcs) throws Exception {
|
||||||
|
FST<Object> fst = buildFST(letters, allowArrayArcs);
|
||||||
|
FST.Arc<Object> first = fst.getFirstArc(new FST.Arc<>());
|
||||||
|
FST.Arc<Object> arc = new FST.Arc<>();
|
||||||
|
FST.BytesReader in = fst.getBytesReader();
|
||||||
|
for (String letter : letters) {
|
||||||
|
char c = letter.charAt(0);
|
||||||
|
arc = Util.readCeilArc(c, fst, first, arc, in);
|
||||||
|
assertNotNull(arc);
|
||||||
|
assertEquals(c, arc.label());
|
||||||
|
}
|
||||||
|
// before the first
|
||||||
|
assertEquals('A', Util.readCeilArc(' ', fst, first, arc, in).label());
|
||||||
|
// after the last
|
||||||
|
assertNull(Util.readCeilArc('~', fst, first, arc, in));
|
||||||
|
// in the middle
|
||||||
|
assertEquals('J', Util.readCeilArc('F', fst, first, arc, in).label());
|
||||||
|
// no following arcs
|
||||||
|
assertNull(Util.readCeilArc('Z', fst, arc, arc, in));
|
||||||
|
}
|
||||||
|
|
||||||
|
private FST<Object> buildFST(List<String> words, boolean allowArrayArcs) throws Exception {
|
||||||
|
final Outputs<Object> outputs = NoOutputs.getSingleton();
|
||||||
|
final Builder<Object> b = new Builder<>(FST.INPUT_TYPE.BYTE1, 0, 0, true, true, Integer.MAX_VALUE, outputs, allowArrayArcs, 15);
|
||||||
|
|
||||||
|
for (String word : words) {
|
||||||
|
b.add(Util.toIntsRef(new BytesRef(word), new IntsRefBuilder()), outputs.getNoOutput());
|
||||||
|
}
|
||||||
|
return b.finish();
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<String> createRandomDictionary(int width, int depth) {
|
||||||
|
return createRandomDictionary(new ArrayList<>(), new StringBuilder(), width, depth);
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<String> createRandomDictionary(List<String> dict, StringBuilder buf, int width, int depth) {
|
||||||
|
char c = (char) random().nextInt(128);
|
||||||
|
assert width < Character.MIN_SURROGATE / 8 - 128; // avoid surrogate chars
|
||||||
|
int len = buf.length();
|
||||||
|
for (int i = 0; i < width; i++) {
|
||||||
|
buf.append(c);
|
||||||
|
if (depth > 0) {
|
||||||
|
createRandomDictionary(dict, buf, width, depth - 1);
|
||||||
|
} else {
|
||||||
|
dict.add(buf.toString());
|
||||||
|
}
|
||||||
|
c += random().nextInt(8);
|
||||||
|
buf.setLength(len);
|
||||||
|
}
|
||||||
|
return dict;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in New Issue