mirror of https://github.com/apache/lucene.git
LUCENE-8920: refactor FST binary search
This commit is contained in:
parent
fe0c042470
commit
92d4e712d5
|
@ -81,7 +81,7 @@ public final class BytesRefFSTEnum<T> extends FSTEnum<T> {
|
|||
public InputOutput<T> seekExact(BytesRef target) throws IOException {
|
||||
this.target = target;
|
||||
targetLength = target.length;
|
||||
if (super.doSeekExact()) {
|
||||
if (doSeekExact()) {
|
||||
assert upto == 1+target.length;
|
||||
return setResult();
|
||||
} else {
|
||||
|
|
|
@ -207,36 +207,12 @@ abstract class FSTEnum<T> {
|
|||
|
||||
private FST.Arc<T> doSeekCeilArrayPacked(final FST.Arc<T> arc, final int targetLabel, final FST.BytesReader in) throws IOException {
|
||||
// The array is packed -- use binary search to find the target.
|
||||
|
||||
int low = arc.arcIdx();
|
||||
int high = arc.numArcs() -1;
|
||||
int mid = 0;
|
||||
//System.out.println("do arc array low=" + low + " high=" + high + " targetLabel=" + targetLabel);
|
||||
boolean found = false;
|
||||
while (low <= high) {
|
||||
mid = (low + high) >>> 1;
|
||||
in.setPosition(arc.posArcsStart());
|
||||
in.skipBytes(arc.bytesPerArc() * mid + 1);
|
||||
final int midLabel = fst.readLabel(in);
|
||||
final int cmp = midLabel - targetLabel;
|
||||
//System.out.println(" cycle low=" + low + " high=" + high + " mid=" + mid + " midLabel=" + midLabel + " cmp=" + cmp);
|
||||
if (cmp < 0)
|
||||
low = mid + 1;
|
||||
else if (cmp > 0)
|
||||
high = mid - 1;
|
||||
else {
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: this code is dup'd w/ the code below (in
|
||||
// the outer else clause):
|
||||
if (found) {
|
||||
int idx = Util.binarySearch(fst, arc, targetLabel);
|
||||
if (idx >= 0) {
|
||||
// Match
|
||||
fst.readArcByIndex(arc, in, mid);
|
||||
assert arc.arcIdx() == mid;
|
||||
assert arc.label() == targetLabel: "arc.label=" + arc.label() + " vs targetLabel=" + targetLabel + " mid=" + mid;
|
||||
fst.readArcByIndex(arc, in, idx);
|
||||
assert arc.arcIdx() == idx;
|
||||
assert arc.label() == targetLabel: "arc.label=" + arc.label() + " vs targetLabel=" + targetLabel + " mid=" + idx;
|
||||
output[upto] = fst.outputs.add(output[upto-1], arc.output());
|
||||
if (targetLabel == FST.END_LABEL) {
|
||||
return null;
|
||||
|
@ -244,9 +220,11 @@ abstract class FSTEnum<T> {
|
|||
setCurrentLabel(arc.label());
|
||||
incr();
|
||||
return fst.readFirstTargetArc(arc, getArc(upto), fstReader);
|
||||
} else if (low == arc.numArcs()) {
|
||||
}
|
||||
idx = -1 - idx;
|
||||
if (idx == arc.numArcs()) {
|
||||
// Dead end
|
||||
fst.readArcByIndex(arc, in, arc.numArcs() - 1);
|
||||
fst.readArcByIndex(arc, in, idx - 1);
|
||||
assert arc.isLast();
|
||||
// Dead end (target is after the last arc);
|
||||
// rollback to last fork then push
|
||||
|
@ -265,7 +243,8 @@ abstract class FSTEnum<T> {
|
|||
upto--;
|
||||
}
|
||||
} else {
|
||||
fst.readArcByIndex(arc, in, low);
|
||||
// Ceiling - arc with least higher label
|
||||
fst.readArcByIndex(arc, in, idx);
|
||||
assert arc.label() > targetLabel;
|
||||
pushFirst();
|
||||
return null;
|
||||
|
@ -314,7 +293,7 @@ abstract class FSTEnum<T> {
|
|||
// Todo: should we return a status here (SEEK_FOUND / SEEK_NOT_FOUND /
|
||||
// SEEK_END)? saves the eq check above?
|
||||
/** Seeks to largest term that's <= target. */
|
||||
protected void doSeekFloor() throws IOException {
|
||||
void doSeekFloor() throws IOException {
|
||||
|
||||
// TODO: possibly caller could/should provide common
|
||||
// prefix length? ie this work may be redundant if
|
||||
|
@ -417,37 +396,14 @@ abstract class FSTEnum<T> {
|
|||
|
||||
private FST.Arc<T> doSeekFloorArrayPacked(FST.Arc<T> arc, int targetLabel, final FST.BytesReader in) throws IOException {
|
||||
// Arcs are fixed array -- use binary search to find the target.
|
||||
int idx = Util.binarySearch(fst, arc, targetLabel);
|
||||
|
||||
int low = arc.arcIdx();
|
||||
int high = arc.numArcs() -1;
|
||||
int mid = 0;
|
||||
//System.out.println("do arc array low=" + low + " high=" + high + " targetLabel=" + targetLabel);
|
||||
boolean found = false;
|
||||
while (low <= high) {
|
||||
mid = (low + high) >>> 1;
|
||||
in.setPosition(arc.posArcsStart());
|
||||
in.skipBytes(arc.bytesPerArc() * mid + 1);
|
||||
final int midLabel = fst.readLabel(in);
|
||||
final int cmp = midLabel - targetLabel;
|
||||
//System.out.println(" cycle low=" + low + " high=" + high + " mid=" + mid + " midLabel=" + midLabel + " cmp=" + cmp);
|
||||
if (cmp < 0) {
|
||||
low = mid + 1;
|
||||
} else if (cmp > 0) {
|
||||
high = mid - 1;
|
||||
} else {
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: this code is dup'd w/ the code below (in
|
||||
// the outer else clause):
|
||||
if (found) {
|
||||
if (idx >= 0) {
|
||||
// Match -- recurse
|
||||
//System.out.println(" match! arcIdx=" + mid);
|
||||
fst.readArcByIndex(arc, in, mid);
|
||||
assert arc.arcIdx() == mid;
|
||||
assert arc.label() == targetLabel: "arc.label=" + arc.label() + " vs targetLabel=" + targetLabel + " mid=" + mid;
|
||||
//System.out.println(" match! arcIdx=" + idx);
|
||||
fst.readArcByIndex(arc, in, idx);
|
||||
assert arc.arcIdx() == idx;
|
||||
assert arc.label() == targetLabel: "arc.label=" + arc.label() + " vs targetLabel=" + targetLabel + " mid=" + idx;
|
||||
output[upto] = fst.outputs.add(output[upto-1], arc.output());
|
||||
if (targetLabel == FST.END_LABEL) {
|
||||
return null;
|
||||
|
@ -455,7 +411,7 @@ abstract class FSTEnum<T> {
|
|||
setCurrentLabel(arc.label());
|
||||
incr();
|
||||
return fst.readFirstTargetArc(arc, getArc(upto), fstReader);
|
||||
} else if (high == -1) {
|
||||
} else if (idx == -1) {
|
||||
//System.out.println(" before first");
|
||||
// Very first arc is after our target
|
||||
// TODO: if each arc could somehow read the arc just
|
||||
|
@ -483,8 +439,8 @@ abstract class FSTEnum<T> {
|
|||
arc = getArc(upto);
|
||||
}
|
||||
} else {
|
||||
// There is a floor arc:
|
||||
fst.readArcByIndex(arc, in, high);
|
||||
// There is a floor arc; idx will be {@code -1 - (floor + 1)}.
|
||||
fst.readArcByIndex(arc, in, -2 - idx);
|
||||
assert arc.isLast() || fst.readNextArcLabel(arc, in) > targetLabel;
|
||||
assert arc.label() < targetLabel: "arc.label=" + arc.label() + " vs targetLabel=" + targetLabel;
|
||||
pushLast();
|
||||
|
@ -652,4 +608,5 @@ abstract class FSTEnum<T> {
|
|||
}
|
||||
return arcs[idx];
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -81,7 +81,7 @@ public final class IntsRefFSTEnum<T> extends FSTEnum<T> {
|
|||
public InputOutput<T> seekExact(IntsRef target) throws IOException {
|
||||
this.target = target;
|
||||
targetLength = target.length;
|
||||
if (super.doSeekExact()) {
|
||||
if (doSeekExact()) {
|
||||
assert upto == 1+target.length;
|
||||
return setResult();
|
||||
} else {
|
||||
|
|
|
@ -951,35 +951,17 @@ public final class Util {
|
|||
return fst.readArcAtPosition(arc, in, arc.posArcsStart() - offset * arc.bytesPerArc());
|
||||
}
|
||||
}
|
||||
// Arcs are packed array -- use binary search to find
|
||||
// the target.
|
||||
|
||||
int low = arc.arcIdx();
|
||||
int mid = 0;
|
||||
int high = arc.numArcs() - 1;
|
||||
// System.out.println("do arc array low=" + low + " high=" + high +
|
||||
// " targetLabel=" + targetLabel);
|
||||
while (low <= high) {
|
||||
mid = (low + high) >>> 1;
|
||||
in.setPosition(arc.posArcsStart());
|
||||
in.skipBytes(arc.bytesPerArc() * mid + 1);
|
||||
final int midLabel = fst.readLabel(in);
|
||||
final int cmp = midLabel - label;
|
||||
// System.out.println(" cycle low=" + low + " high=" + high + " mid=" +
|
||||
// mid + " midLabel=" + midLabel + " cmp=" + cmp);
|
||||
if (cmp < 0) {
|
||||
low = mid + 1;
|
||||
} else if (cmp > 0) {
|
||||
high = mid - 1;
|
||||
} else {
|
||||
return fst.readArcByIndex(arc, in, mid);
|
||||
// Arcs are packed array -- use binary search to find the target.
|
||||
int idx = binarySearch(fst, arc, label);
|
||||
if (idx >= 0) {
|
||||
return fst.readArcByIndex(arc, in, idx);
|
||||
}
|
||||
}
|
||||
if (low == arc.numArcs()) {
|
||||
idx = -1 - idx;
|
||||
if (idx == arc.numArcs()) {
|
||||
// DEAD END!
|
||||
return null;
|
||||
}
|
||||
return fst.readArcByIndex(arc, in , high + 1);
|
||||
return fst.readArcByIndex(arc, in , idx);
|
||||
}
|
||||
|
||||
// Linear scan
|
||||
|
@ -1001,4 +983,36 @@ public final class Util {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform a binary search of Arcs encoded as a packed array
|
||||
* @param fst the FST from which to read
|
||||
* @param arc the starting arc; sibling arcs greater than this will be searched. Usually the first arc in the array.
|
||||
* @param targetLabel the label to search for
|
||||
* @param <T> the output type of the FST
|
||||
* @return the index of the Arc having the target label, or if no Arc has the matching label, {@code -1 - idx)},
|
||||
* where {@code idx} is the index of the Arc with the next highest label, or the total number of arcs
|
||||
* if the target label exceeds the maximum.
|
||||
* @throws IOException when the FST reader does
|
||||
*/
|
||||
static <T> int binarySearch(FST<T> fst, FST.Arc<T> arc, int targetLabel) throws IOException {
|
||||
BytesReader in = fst.getBytesReader();
|
||||
int low = arc.arcIdx();
|
||||
int mid = 0;
|
||||
int high = arc.numArcs() -1;
|
||||
while (low <= high) {
|
||||
mid = (low + high) >>> 1;
|
||||
in.setPosition(arc.posArcsStart());
|
||||
in.skipBytes(arc.bytesPerArc() * mid + 1);
|
||||
final int midLabel = fst.readLabel(in);
|
||||
final int cmp = midLabel - targetLabel;
|
||||
if (cmp < 0) {
|
||||
low = mid + 1;
|
||||
} else if (cmp > 0) {
|
||||
high = mid - 1;
|
||||
} else {
|
||||
return mid;
|
||||
}
|
||||
}
|
||||
return -1 - low;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,115 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.util.fst;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.IntsRefBuilder;
|
||||
import org.apache.lucene.util.LuceneTestCase;
|
||||
|
||||
public class TestUtil extends LuceneTestCase {
|
||||
|
||||
public void testBinarySearch() throws Exception {
|
||||
// Creates a node with 8 arcs spanning (z-A) = 57 chars that will be encoded as a sparse array (no gaps)
|
||||
// requiring binary search
|
||||
List<String> letters = Arrays.asList("A", "E", "J", "K", "L", "O", "T", "z");
|
||||
FST<Object> fst = buildFST(letters, true);
|
||||
FST.Arc<Object> arc = fst.getFirstArc(new FST.Arc<>());
|
||||
arc = fst.readFirstTargetArc(arc, arc, fst.getBytesReader());
|
||||
for (int i = 0; i < letters.size(); i++) {
|
||||
assertEquals(i, Util.binarySearch(fst, arc, letters.get(i).charAt(0)));
|
||||
}
|
||||
// before the first
|
||||
assertEquals(-1, Util.binarySearch(fst, arc, ' '));
|
||||
// after the last
|
||||
assertEquals(-1 - letters.size(), Util.binarySearch(fst, arc, '~'));
|
||||
assertEquals(-2, Util.binarySearch(fst, arc, 'B'));
|
||||
assertEquals(-2, Util.binarySearch(fst, arc, 'C'));
|
||||
assertEquals(-7, Util.binarySearch(fst, arc, 'P'));
|
||||
}
|
||||
|
||||
public void testReadCeilArcPackedArray() throws Exception {
|
||||
List<String> letters = Arrays.asList("A", "E", "J", "K", "L", "O", "T", "z");
|
||||
verifyReadCeilArc(letters, true);
|
||||
}
|
||||
|
||||
public void testReadCeilArcArrayWithGaps() throws Exception {
|
||||
List<String> letters = Arrays.asList("A", "E", "J", "K", "L", "O", "T");
|
||||
verifyReadCeilArc(letters, true);
|
||||
}
|
||||
|
||||
public void testReadCeilArcList() throws Exception {
|
||||
List<String> letters = Arrays.asList("A", "E", "J", "K", "L", "O", "T", "z");
|
||||
verifyReadCeilArc(letters, false);
|
||||
}
|
||||
|
||||
private void verifyReadCeilArc(List<String> letters, boolean allowArrayArcs) throws Exception {
|
||||
FST<Object> fst = buildFST(letters, allowArrayArcs);
|
||||
FST.Arc<Object> first = fst.getFirstArc(new FST.Arc<>());
|
||||
FST.Arc<Object> arc = new FST.Arc<>();
|
||||
FST.BytesReader in = fst.getBytesReader();
|
||||
for (String letter : letters) {
|
||||
char c = letter.charAt(0);
|
||||
arc = Util.readCeilArc(c, fst, first, arc, in);
|
||||
assertNotNull(arc);
|
||||
assertEquals(c, arc.label());
|
||||
}
|
||||
// before the first
|
||||
assertEquals('A', Util.readCeilArc(' ', fst, first, arc, in).label());
|
||||
// after the last
|
||||
assertNull(Util.readCeilArc('~', fst, first, arc, in));
|
||||
// in the middle
|
||||
assertEquals('J', Util.readCeilArc('F', fst, first, arc, in).label());
|
||||
// no following arcs
|
||||
assertNull(Util.readCeilArc('Z', fst, arc, arc, in));
|
||||
}
|
||||
|
||||
private FST<Object> buildFST(List<String> words, boolean allowArrayArcs) throws Exception {
|
||||
final Outputs<Object> outputs = NoOutputs.getSingleton();
|
||||
final Builder<Object> b = new Builder<>(FST.INPUT_TYPE.BYTE1, 0, 0, true, true, Integer.MAX_VALUE, outputs, allowArrayArcs, 15);
|
||||
|
||||
for (String word : words) {
|
||||
b.add(Util.toIntsRef(new BytesRef(word), new IntsRefBuilder()), outputs.getNoOutput());
|
||||
}
|
||||
return b.finish();
|
||||
}
|
||||
|
||||
private List<String> createRandomDictionary(int width, int depth) {
|
||||
return createRandomDictionary(new ArrayList<>(), new StringBuilder(), width, depth);
|
||||
}
|
||||
|
||||
private List<String> createRandomDictionary(List<String> dict, StringBuilder buf, int width, int depth) {
|
||||
char c = (char) random().nextInt(128);
|
||||
assert width < Character.MIN_SURROGATE / 8 - 128; // avoid surrogate chars
|
||||
int len = buf.length();
|
||||
for (int i = 0; i < width; i++) {
|
||||
buf.append(c);
|
||||
if (depth > 0) {
|
||||
createRandomDictionary(dict, buf, width, depth - 1);
|
||||
} else {
|
||||
dict.add(buf.toString());
|
||||
}
|
||||
c += random().nextInt(8);
|
||||
buf.setLength(len);
|
||||
}
|
||||
return dict;
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue