From 92d4e712d5d50d745c5a6c10dacda66198974116 Mon Sep 17 00:00:00 2001 From: Michael Sokolov <sokolov@apache.org> Date: Thu, 4 Jul 2019 10:45:17 -0400 Subject: [PATCH] LUCENE-8920: refactor FST binary search --- .../codecs/memory/FSTOrdTermsReader.java | 2 +- .../lucene/util/fst/BytesRefFSTEnum.java | 2 +- .../org/apache/lucene/util/fst/FSTEnum.java | 89 ++++---------- .../lucene/util/fst/IntsRefFSTEnum.java | 2 +- .../java/org/apache/lucene/util/fst/Util.java | 64 ++++++---- .../org/apache/lucene/util/fst/TestUtil.java | 115 ++++++++++++++++++ 6 files changed, 180 insertions(+), 94 deletions(-) create mode 100644 lucene/core/src/test/org/apache/lucene/util/fst/TestUtil.java diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/memory/FSTOrdTermsReader.java b/lucene/codecs/src/java/org/apache/lucene/codecs/memory/FSTOrdTermsReader.java index daba6096c93..0fa8ebe0524 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/memory/FSTOrdTermsReader.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/memory/FSTOrdTermsReader.java @@ -305,7 +305,7 @@ public class FSTOrdTermsReader extends FieldsProducer { } // Only wraps common operations for PBF interact - abstract class BaseTermsEnum extends org.apache.lucene.index.BaseTermsEnum { + abstract class BaseTermsEnum extends org.apache.lucene.index.BaseTermsEnum { /* Current term's ord, starts from 0 */ long ord; diff --git a/lucene/core/src/java/org/apache/lucene/util/fst/BytesRefFSTEnum.java b/lucene/core/src/java/org/apache/lucene/util/fst/BytesRefFSTEnum.java index 97d6a0eb357..19e2a01c575 100644 --- a/lucene/core/src/java/org/apache/lucene/util/fst/BytesRefFSTEnum.java +++ b/lucene/core/src/java/org/apache/lucene/util/fst/BytesRefFSTEnum.java @@ -81,7 +81,7 @@ public final class BytesRefFSTEnum<T> extends FSTEnum<T> { public InputOutput<T> seekExact(BytesRef target) throws IOException { this.target = target; targetLength = target.length; - if (super.doSeekExact()) { + if (doSeekExact()) { assert upto == 1+target.length; return setResult(); } else { diff --git a/lucene/core/src/java/org/apache/lucene/util/fst/FSTEnum.java b/lucene/core/src/java/org/apache/lucene/util/fst/FSTEnum.java index feeddf3fb6b..36d8ddd384a 100644 --- a/lucene/core/src/java/org/apache/lucene/util/fst/FSTEnum.java +++ b/lucene/core/src/java/org/apache/lucene/util/fst/FSTEnum.java @@ -24,7 +24,7 @@ import org.apache.lucene.util.RamUsageEstimator; /** Can next() and advance() through the terms in an FST * - * @lucene.experimental + * @lucene.experimental */ abstract class FSTEnum<T> { @@ -207,36 +207,12 @@ abstract class FSTEnum<T> { private FST.Arc<T> doSeekCeilArrayPacked(final FST.Arc<T> arc, final int targetLabel, final FST.BytesReader in) throws IOException { // The array is packed -- use binary search to find the target. - - int low = arc.arcIdx(); - int high = arc.numArcs() -1; - int mid = 0; - //System.out.println("do arc array low=" + low + " high=" + high + " targetLabel=" + targetLabel); - boolean found = false; - while (low <= high) { - mid = (low + high) >>> 1; - in.setPosition(arc.posArcsStart()); - in.skipBytes(arc.bytesPerArc() * mid + 1); - final int midLabel = fst.readLabel(in); - final int cmp = midLabel - targetLabel; - //System.out.println(" cycle low=" + low + " high=" + high + " mid=" + mid + " midLabel=" + midLabel + " cmp=" + cmp); - if (cmp < 0) - low = mid + 1; - else if (cmp > 0) - high = mid - 1; - else { - found = true; - break; - } - } - - // NOTE: this code is dup'd w/ the code below (in - // the outer else clause): - if (found) { + int idx = Util.binarySearch(fst, arc, targetLabel); + if (idx >= 0) { // Match - fst.readArcByIndex(arc, in, mid); - assert arc.arcIdx() == mid; - assert arc.label() == targetLabel: "arc.label=" + arc.label() + " vs targetLabel=" + targetLabel + " mid=" + mid; + fst.readArcByIndex(arc, in, idx); + assert arc.arcIdx() == idx; + assert arc.label() == targetLabel: "arc.label=" + arc.label() + " vs targetLabel=" + targetLabel + " mid=" + idx; output[upto] = fst.outputs.add(output[upto-1], arc.output()); if (targetLabel == FST.END_LABEL) { return null; @@ -244,9 +220,11 @@ abstract class FSTEnum<T> { setCurrentLabel(arc.label()); incr(); return fst.readFirstTargetArc(arc, getArc(upto), fstReader); - } else if (low == arc.numArcs()) { + } + idx = -1 - idx; + if (idx == arc.numArcs()) { // Dead end - fst.readArcByIndex(arc, in, arc.numArcs() - 1); + fst.readArcByIndex(arc, in, idx - 1); assert arc.isLast(); // Dead end (target is after the last arc); // rollback to last fork then push @@ -265,7 +243,8 @@ abstract class FSTEnum<T> { upto--; } } else { - fst.readArcByIndex(arc, in, low); + // Ceiling - arc with least higher label + fst.readArcByIndex(arc, in, idx); assert arc.label() > targetLabel; pushFirst(); return null; @@ -314,7 +293,7 @@ abstract class FSTEnum<T> { // Todo: should we return a status here (SEEK_FOUND / SEEK_NOT_FOUND / // SEEK_END)? saves the eq check above? /** Seeks to largest term that's <= target. */ - protected void doSeekFloor() throws IOException { + void doSeekFloor() throws IOException { // TODO: possibly caller could/should provide common // prefix length? ie this work may be redundant if @@ -417,37 +396,14 @@ abstract class FSTEnum<T> { private FST.Arc<T> doSeekFloorArrayPacked(FST.Arc<T> arc, int targetLabel, final FST.BytesReader in) throws IOException { // Arcs are fixed array -- use binary search to find the target. + int idx = Util.binarySearch(fst, arc, targetLabel); - int low = arc.arcIdx(); - int high = arc.numArcs() -1; - int mid = 0; - //System.out.println("do arc array low=" + low + " high=" + high + " targetLabel=" + targetLabel); - boolean found = false; - while (low <= high) { - mid = (low + high) >>> 1; - in.setPosition(arc.posArcsStart()); - in.skipBytes(arc.bytesPerArc() * mid + 1); - final int midLabel = fst.readLabel(in); - final int cmp = midLabel - targetLabel; - //System.out.println(" cycle low=" + low + " high=" + high + " mid=" + mid + " midLabel=" + midLabel + " cmp=" + cmp); - if (cmp < 0) { - low = mid + 1; - } else if (cmp > 0) { - high = mid - 1; - } else { - found = true; - break; - } - } - - // NOTE: this code is dup'd w/ the code below (in - // the outer else clause): - if (found) { + if (idx >= 0) { // Match -- recurse - //System.out.println(" match! arcIdx=" + mid); - fst.readArcByIndex(arc, in, mid); - assert arc.arcIdx() == mid; - assert arc.label() == targetLabel: "arc.label=" + arc.label() + " vs targetLabel=" + targetLabel + " mid=" + mid; + //System.out.println(" match! arcIdx=" + idx); + fst.readArcByIndex(arc, in, idx); + assert arc.arcIdx() == idx; + assert arc.label() == targetLabel: "arc.label=" + arc.label() + " vs targetLabel=" + targetLabel + " mid=" + idx; output[upto] = fst.outputs.add(output[upto-1], arc.output()); if (targetLabel == FST.END_LABEL) { return null; @@ -455,7 +411,7 @@ abstract class FSTEnum<T> { setCurrentLabel(arc.label()); incr(); return fst.readFirstTargetArc(arc, getArc(upto), fstReader); - } else if (high == -1) { + } else if (idx == -1) { //System.out.println(" before first"); // Very first arc is after our target // TODO: if each arc could somehow read the arc just @@ -483,8 +439,8 @@ abstract class FSTEnum<T> { arc = getArc(upto); } } else { - // There is a floor arc: - fst.readArcByIndex(arc, in, high); + // There is a floor arc; idx will be {@code -1 - (floor + 1)}. + fst.readArcByIndex(arc, in, -2 - idx); assert arc.isLast() || fst.readNextArcLabel(arc, in) > targetLabel; assert arc.label() < targetLabel: "arc.label=" + arc.label() + " vs targetLabel=" + targetLabel; pushLast(); @@ -652,4 +608,5 @@ abstract class FSTEnum<T> { } return arcs[idx]; } + } diff --git a/lucene/core/src/java/org/apache/lucene/util/fst/IntsRefFSTEnum.java b/lucene/core/src/java/org/apache/lucene/util/fst/IntsRefFSTEnum.java index f485854a748..2c05c965fbd 100644 --- a/lucene/core/src/java/org/apache/lucene/util/fst/IntsRefFSTEnum.java +++ b/lucene/core/src/java/org/apache/lucene/util/fst/IntsRefFSTEnum.java @@ -81,7 +81,7 @@ public final class IntsRefFSTEnum<T> extends FSTEnum<T> { public InputOutput<T> seekExact(IntsRef target) throws IOException { this.target = target; targetLength = target.length; - if (super.doSeekExact()) { + if (doSeekExact()) { assert upto == 1+target.length; return setResult(); } else { diff --git a/lucene/core/src/java/org/apache/lucene/util/fst/Util.java b/lucene/core/src/java/org/apache/lucene/util/fst/Util.java index e033267ffbe..ddecaded731 100644 --- a/lucene/core/src/java/org/apache/lucene/util/fst/Util.java +++ b/lucene/core/src/java/org/apache/lucene/util/fst/Util.java @@ -951,35 +951,17 @@ public final class Util { return fst.readArcAtPosition(arc, in, arc.posArcsStart() - offset * arc.bytesPerArc()); } } - // Arcs are packed array -- use binary search to find - // the target. - - int low = arc.arcIdx(); - int mid = 0; - int high = arc.numArcs() - 1; - // System.out.println("do arc array low=" + low + " high=" + high + - // " targetLabel=" + targetLabel); - while (low <= high) { - mid = (low + high) >>> 1; - in.setPosition(arc.posArcsStart()); - in.skipBytes(arc.bytesPerArc() * mid + 1); - final int midLabel = fst.readLabel(in); - final int cmp = midLabel - label; - // System.out.println(" cycle low=" + low + " high=" + high + " mid=" + - // mid + " midLabel=" + midLabel + " cmp=" + cmp); - if (cmp < 0) { - low = mid + 1; - } else if (cmp > 0) { - high = mid - 1; - } else { - return fst.readArcByIndex(arc, in, mid); - } + // Arcs are packed array -- use binary search to find the target. + int idx = binarySearch(fst, arc, label); + if (idx >= 0) { + return fst.readArcByIndex(arc, in, idx); } - if (low == arc.numArcs()) { + idx = -1 - idx; + if (idx == arc.numArcs()) { // DEAD END! return null; } - return fst.readArcByIndex(arc, in , high + 1); + return fst.readArcByIndex(arc, in , idx); } // Linear scan @@ -1001,4 +983,36 @@ public final class Util { } } + /** + * Perform a binary search of Arcs encoded as a packed array + * @param fst the FST from which to read + * @param arc the starting arc; sibling arcs greater than this will be searched. Usually the first arc in the array. + * @param targetLabel the label to search for + * @param <T> the output type of the FST + * @return the index of the Arc having the target label, or if no Arc has the matching label, {@code -1 - idx)}, + * where {@code idx} is the index of the Arc with the next highest label, or the total number of arcs + * if the target label exceeds the maximum. + * @throws IOException when the FST reader does + */ + static <T> int binarySearch(FST<T> fst, FST.Arc<T> arc, int targetLabel) throws IOException { + BytesReader in = fst.getBytesReader(); + int low = arc.arcIdx(); + int mid = 0; + int high = arc.numArcs() -1; + while (low <= high) { + mid = (low + high) >>> 1; + in.setPosition(arc.posArcsStart()); + in.skipBytes(arc.bytesPerArc() * mid + 1); + final int midLabel = fst.readLabel(in); + final int cmp = midLabel - targetLabel; + if (cmp < 0) { + low = mid + 1; + } else if (cmp > 0) { + high = mid - 1; + } else { + return mid; + } + } + return -1 - low; + } } diff --git a/lucene/core/src/test/org/apache/lucene/util/fst/TestUtil.java b/lucene/core/src/test/org/apache/lucene/util/fst/TestUtil.java new file mode 100644 index 00000000000..5ec163e9cfe --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/util/fst/TestUtil.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.util.fst; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.IntsRefBuilder; +import org.apache.lucene.util.LuceneTestCase; + +public class TestUtil extends LuceneTestCase { + + public void testBinarySearch() throws Exception { + // Creates a node with 8 arcs spanning (z-A) = 57 chars that will be encoded as a sparse array (no gaps) + // requiring binary search + List<String> letters = Arrays.asList("A", "E", "J", "K", "L", "O", "T", "z"); + FST<Object> fst = buildFST(letters, true); + FST.Arc<Object> arc = fst.getFirstArc(new FST.Arc<>()); + arc = fst.readFirstTargetArc(arc, arc, fst.getBytesReader()); + for (int i = 0; i < letters.size(); i++) { + assertEquals(i, Util.binarySearch(fst, arc, letters.get(i).charAt(0))); + } + // before the first + assertEquals(-1, Util.binarySearch(fst, arc, ' ')); + // after the last + assertEquals(-1 - letters.size(), Util.binarySearch(fst, arc, '~')); + assertEquals(-2, Util.binarySearch(fst, arc, 'B')); + assertEquals(-2, Util.binarySearch(fst, arc, 'C')); + assertEquals(-7, Util.binarySearch(fst, arc, 'P')); + } + + public void testReadCeilArcPackedArray() throws Exception { + List<String> letters = Arrays.asList("A", "E", "J", "K", "L", "O", "T", "z"); + verifyReadCeilArc(letters, true); + } + + public void testReadCeilArcArrayWithGaps() throws Exception { + List<String> letters = Arrays.asList("A", "E", "J", "K", "L", "O", "T"); + verifyReadCeilArc(letters, true); + } + + public void testReadCeilArcList() throws Exception { + List<String> letters = Arrays.asList("A", "E", "J", "K", "L", "O", "T", "z"); + verifyReadCeilArc(letters, false); + } + + private void verifyReadCeilArc(List<String> letters, boolean allowArrayArcs) throws Exception { + FST<Object> fst = buildFST(letters, allowArrayArcs); + FST.Arc<Object> first = fst.getFirstArc(new FST.Arc<>()); + FST.Arc<Object> arc = new FST.Arc<>(); + FST.BytesReader in = fst.getBytesReader(); + for (String letter : letters) { + char c = letter.charAt(0); + arc = Util.readCeilArc(c, fst, first, arc, in); + assertNotNull(arc); + assertEquals(c, arc.label()); + } + // before the first + assertEquals('A', Util.readCeilArc(' ', fst, first, arc, in).label()); + // after the last + assertNull(Util.readCeilArc('~', fst, first, arc, in)); + // in the middle + assertEquals('J', Util.readCeilArc('F', fst, first, arc, in).label()); + // no following arcs + assertNull(Util.readCeilArc('Z', fst, arc, arc, in)); + } + + private FST<Object> buildFST(List<String> words, boolean allowArrayArcs) throws Exception { + final Outputs<Object> outputs = NoOutputs.getSingleton(); + final Builder<Object> b = new Builder<>(FST.INPUT_TYPE.BYTE1, 0, 0, true, true, Integer.MAX_VALUE, outputs, allowArrayArcs, 15); + + for (String word : words) { + b.add(Util.toIntsRef(new BytesRef(word), new IntsRefBuilder()), outputs.getNoOutput()); + } + return b.finish(); + } + + private List<String> createRandomDictionary(int width, int depth) { + return createRandomDictionary(new ArrayList<>(), new StringBuilder(), width, depth); + } + + private List<String> createRandomDictionary(List<String> dict, StringBuilder buf, int width, int depth) { + char c = (char) random().nextInt(128); + assert width < Character.MIN_SURROGATE / 8 - 128; // avoid surrogate chars + int len = buf.length(); + for (int i = 0; i < width; i++) { + buf.append(c); + if (depth > 0) { + createRandomDictionary(dict, buf, width, depth - 1); + } else { + dict.add(buf.toString()); + } + c += random().nextInt(8); + buf.setLength(len); + } + return dict; + } + +}