From 92d4e712d5d50d745c5a6c10dacda66198974116 Mon Sep 17 00:00:00 2001
From: Michael Sokolov <sokolov@apache.org>
Date: Thu, 4 Jul 2019 10:45:17 -0400
Subject: [PATCH] LUCENE-8920: refactor FST binary search

---
 .../codecs/memory/FSTOrdTermsReader.java      |   2 +-
 .../lucene/util/fst/BytesRefFSTEnum.java      |   2 +-
 .../org/apache/lucene/util/fst/FSTEnum.java   |  89 ++++----------
 .../lucene/util/fst/IntsRefFSTEnum.java       |   2 +-
 .../java/org/apache/lucene/util/fst/Util.java |  64 ++++++----
 .../org/apache/lucene/util/fst/TestUtil.java  | 115 ++++++++++++++++++
 6 files changed, 180 insertions(+), 94 deletions(-)
 create mode 100644 lucene/core/src/test/org/apache/lucene/util/fst/TestUtil.java

diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/memory/FSTOrdTermsReader.java b/lucene/codecs/src/java/org/apache/lucene/codecs/memory/FSTOrdTermsReader.java
index daba6096c93..0fa8ebe0524 100644
--- a/lucene/codecs/src/java/org/apache/lucene/codecs/memory/FSTOrdTermsReader.java
+++ b/lucene/codecs/src/java/org/apache/lucene/codecs/memory/FSTOrdTermsReader.java
@@ -305,7 +305,7 @@ public class FSTOrdTermsReader extends FieldsProducer {
     }
 
     // Only wraps common operations for PBF interact
-    abstract class  BaseTermsEnum extends org.apache.lucene.index.BaseTermsEnum {
+    abstract class BaseTermsEnum extends org.apache.lucene.index.BaseTermsEnum {
 
       /* Current term's ord, starts from 0 */
       long ord;
diff --git a/lucene/core/src/java/org/apache/lucene/util/fst/BytesRefFSTEnum.java b/lucene/core/src/java/org/apache/lucene/util/fst/BytesRefFSTEnum.java
index 97d6a0eb357..19e2a01c575 100644
--- a/lucene/core/src/java/org/apache/lucene/util/fst/BytesRefFSTEnum.java
+++ b/lucene/core/src/java/org/apache/lucene/util/fst/BytesRefFSTEnum.java
@@ -81,7 +81,7 @@ public final class BytesRefFSTEnum<T> extends FSTEnum<T> {
   public InputOutput<T> seekExact(BytesRef target) throws IOException {
     this.target = target;
     targetLength = target.length;
-    if (super.doSeekExact()) {
+    if (doSeekExact()) {
       assert upto == 1+target.length;
       return setResult();
     } else {
diff --git a/lucene/core/src/java/org/apache/lucene/util/fst/FSTEnum.java b/lucene/core/src/java/org/apache/lucene/util/fst/FSTEnum.java
index feeddf3fb6b..36d8ddd384a 100644
--- a/lucene/core/src/java/org/apache/lucene/util/fst/FSTEnum.java
+++ b/lucene/core/src/java/org/apache/lucene/util/fst/FSTEnum.java
@@ -24,7 +24,7 @@ import org.apache.lucene.util.RamUsageEstimator;
 
 /** Can next() and advance() through the terms in an FST
  *
-  * @lucene.experimental
+ * @lucene.experimental
 */
 
 abstract class FSTEnum<T> {
@@ -207,36 +207,12 @@ abstract class FSTEnum<T> {
 
   private FST.Arc<T> doSeekCeilArrayPacked(final FST.Arc<T> arc, final int targetLabel, final FST.BytesReader in) throws IOException {
     // The array is packed -- use binary search to find the target.
-
-    int low = arc.arcIdx();
-    int high = arc.numArcs() -1;
-    int mid = 0;
-    //System.out.println("do arc array low=" + low + " high=" + high + " targetLabel=" + targetLabel);
-    boolean found = false;
-    while (low <= high) {
-      mid = (low + high) >>> 1;
-      in.setPosition(arc.posArcsStart());
-      in.skipBytes(arc.bytesPerArc() * mid + 1);
-      final int midLabel = fst.readLabel(in);
-      final int cmp = midLabel - targetLabel;
-      //System.out.println("  cycle low=" + low + " high=" + high + " mid=" + mid + " midLabel=" + midLabel + " cmp=" + cmp);
-      if (cmp < 0)
-        low = mid + 1;
-      else if (cmp > 0)
-        high = mid - 1;
-      else {
-        found = true;
-        break;
-      }
-    }
-
-    // NOTE: this code is dup'd w/ the code below (in
-    // the outer else clause):
-    if (found) {
+    int idx = Util.binarySearch(fst, arc, targetLabel);
+    if (idx >= 0) {
       // Match
-      fst.readArcByIndex(arc, in, mid);
-      assert arc.arcIdx() == mid;
-      assert arc.label() == targetLabel: "arc.label=" + arc.label() + " vs targetLabel=" + targetLabel + " mid=" + mid;
+      fst.readArcByIndex(arc, in, idx);
+      assert arc.arcIdx() == idx;
+      assert arc.label() == targetLabel: "arc.label=" + arc.label() + " vs targetLabel=" + targetLabel + " mid=" + idx;
       output[upto] = fst.outputs.add(output[upto-1], arc.output());
       if (targetLabel == FST.END_LABEL) {
         return null;
@@ -244,9 +220,11 @@ abstract class FSTEnum<T> {
       setCurrentLabel(arc.label());
       incr();
       return fst.readFirstTargetArc(arc, getArc(upto), fstReader);
-    } else if (low == arc.numArcs()) {
+    }
+    idx = -1 - idx;
+    if (idx == arc.numArcs()) {
       // Dead end
-      fst.readArcByIndex(arc, in, arc.numArcs() - 1);
+      fst.readArcByIndex(arc, in, idx - 1);
       assert arc.isLast();
       // Dead end (target is after the last arc);
       // rollback to last fork then push
@@ -265,7 +243,8 @@ abstract class FSTEnum<T> {
         upto--;
       }
     } else {
-      fst.readArcByIndex(arc, in, low);
+      // Ceiling - arc with least higher label
+      fst.readArcByIndex(arc, in, idx);
       assert arc.label() > targetLabel;
       pushFirst();
       return null;
@@ -314,7 +293,7 @@ abstract class FSTEnum<T> {
   // Todo: should we return a status here (SEEK_FOUND / SEEK_NOT_FOUND /
   // SEEK_END)?  saves the eq check above?
   /** Seeks to largest term that's &lt;= target. */
-  protected void doSeekFloor() throws IOException {
+  void doSeekFloor() throws IOException {
 
     // TODO: possibly caller could/should provide common
     // prefix length?  ie this work may be redundant if
@@ -417,37 +396,14 @@ abstract class FSTEnum<T> {
 
   private FST.Arc<T> doSeekFloorArrayPacked(FST.Arc<T> arc, int targetLabel, final FST.BytesReader in) throws IOException {
     // Arcs are fixed array -- use binary search to find the target.
+    int idx = Util.binarySearch(fst, arc, targetLabel);
 
-    int low = arc.arcIdx();
-    int high = arc.numArcs() -1;
-    int mid = 0;
-    //System.out.println("do arc array low=" + low + " high=" + high + " targetLabel=" + targetLabel);
-    boolean found = false;
-    while (low <= high) {
-      mid = (low + high) >>> 1;
-      in.setPosition(arc.posArcsStart());
-      in.skipBytes(arc.bytesPerArc() * mid + 1);
-      final int midLabel = fst.readLabel(in);
-      final int cmp = midLabel - targetLabel;
-      //System.out.println("  cycle low=" + low + " high=" + high + " mid=" + mid + " midLabel=" + midLabel + " cmp=" + cmp);
-      if (cmp < 0) {
-        low = mid + 1;
-      } else if (cmp > 0) {
-        high = mid - 1;
-      } else {
-        found = true;
-        break;
-      }
-    }
-
-    // NOTE: this code is dup'd w/ the code below (in
-    // the outer else clause):
-    if (found) {
+    if (idx >= 0) {
       // Match -- recurse
-      //System.out.println("  match!  arcIdx=" + mid);
-      fst.readArcByIndex(arc, in, mid);
-      assert arc.arcIdx() == mid;
-      assert arc.label() == targetLabel: "arc.label=" + arc.label() + " vs targetLabel=" + targetLabel + " mid=" + mid;
+      //System.out.println("  match!  arcIdx=" + idx);
+      fst.readArcByIndex(arc, in, idx);
+      assert arc.arcIdx() == idx;
+      assert arc.label() == targetLabel: "arc.label=" + arc.label() + " vs targetLabel=" + targetLabel + " mid=" + idx;
       output[upto] = fst.outputs.add(output[upto-1], arc.output());
       if (targetLabel == FST.END_LABEL) {
         return null;
@@ -455,7 +411,7 @@ abstract class FSTEnum<T> {
       setCurrentLabel(arc.label());
       incr();
       return fst.readFirstTargetArc(arc, getArc(upto), fstReader);
-    } else if (high == -1) {
+    } else if (idx == -1) {
       //System.out.println("  before first");
       // Very first arc is after our target
       // TODO: if each arc could somehow read the arc just
@@ -483,8 +439,8 @@ abstract class FSTEnum<T> {
         arc = getArc(upto);
       }
     } else {
-      // There is a floor arc:
-      fst.readArcByIndex(arc, in, high);
+      // There is a floor arc; idx will be {@code -1 - (floor + 1)}.
+      fst.readArcByIndex(arc, in, -2 - idx);
       assert arc.isLast() || fst.readNextArcLabel(arc, in) > targetLabel;
       assert arc.label() < targetLabel: "arc.label=" + arc.label() + " vs targetLabel=" + targetLabel;
       pushLast();
@@ -652,4 +608,5 @@ abstract class FSTEnum<T> {
     }
     return arcs[idx];
   }
+
 }
diff --git a/lucene/core/src/java/org/apache/lucene/util/fst/IntsRefFSTEnum.java b/lucene/core/src/java/org/apache/lucene/util/fst/IntsRefFSTEnum.java
index f485854a748..2c05c965fbd 100644
--- a/lucene/core/src/java/org/apache/lucene/util/fst/IntsRefFSTEnum.java
+++ b/lucene/core/src/java/org/apache/lucene/util/fst/IntsRefFSTEnum.java
@@ -81,7 +81,7 @@ public final class IntsRefFSTEnum<T> extends FSTEnum<T> {
   public InputOutput<T> seekExact(IntsRef target) throws IOException {
     this.target = target;
     targetLength = target.length;
-    if (super.doSeekExact()) {
+    if (doSeekExact()) {
       assert upto == 1+target.length;
       return setResult();
     } else {
diff --git a/lucene/core/src/java/org/apache/lucene/util/fst/Util.java b/lucene/core/src/java/org/apache/lucene/util/fst/Util.java
index e033267ffbe..ddecaded731 100644
--- a/lucene/core/src/java/org/apache/lucene/util/fst/Util.java
+++ b/lucene/core/src/java/org/apache/lucene/util/fst/Util.java
@@ -951,35 +951,17 @@ public final class Util {
           return fst.readArcAtPosition(arc, in, arc.posArcsStart() - offset * arc.bytesPerArc());
         }
       }
-      // Arcs are packed array -- use binary search to find
-      // the target.
-
-      int low = arc.arcIdx();
-      int mid = 0;
-      int high = arc.numArcs() - 1;
-      // System.out.println("do arc array low=" + low + " high=" + high +
-      // " targetLabel=" + targetLabel);
-      while (low <= high) {
-        mid = (low + high) >>> 1;
-        in.setPosition(arc.posArcsStart());
-        in.skipBytes(arc.bytesPerArc() * mid + 1);
-        final int midLabel = fst.readLabel(in);
-        final int cmp = midLabel - label;
-        // System.out.println("  cycle low=" + low + " high=" + high + " mid=" +
-        // mid + " midLabel=" + midLabel + " cmp=" + cmp);
-        if (cmp < 0) {
-          low = mid + 1;
-        } else if (cmp > 0) {
-          high = mid - 1;
-        } else {
-          return fst.readArcByIndex(arc, in, mid);
-        }
+      // Arcs are packed array -- use binary search to find the target.
+      int idx = binarySearch(fst, arc, label);
+      if (idx >= 0) {
+        return fst.readArcByIndex(arc, in, idx);
       }
-      if (low == arc.numArcs()) {
+      idx = -1 - idx;
+      if (idx == arc.numArcs()) {
         // DEAD END!
         return null;
       }
-      return fst.readArcByIndex(arc, in , high + 1);
+      return fst.readArcByIndex(arc, in , idx);
     }
 
     // Linear scan
@@ -1001,4 +983,36 @@ public final class Util {
     }
   }
 
+  /**
+   * Perform a binary search of Arcs encoded as a packed array
+   * @param fst the FST from which to read
+   * @param arc the starting arc; sibling arcs greater than this will be searched. Usually the first arc in the array.
+   * @param targetLabel the label to search for
+   * @param <T> the output type of the FST
+   * @return the index of the Arc having the target label, or if no Arc has the matching label, {@code -1 - idx)},
+   * where {@code idx} is the index of the Arc with the next highest label, or the total number of arcs
+   * if the target label exceeds the maximum.
+   * @throws IOException when the FST reader does
+   */
+  static <T> int binarySearch(FST<T> fst, FST.Arc<T> arc, int targetLabel) throws IOException {
+    BytesReader in = fst.getBytesReader();
+    int low = arc.arcIdx();
+    int mid = 0;
+    int high = arc.numArcs() -1;
+    while (low <= high) {
+      mid = (low + high) >>> 1;
+      in.setPosition(arc.posArcsStart());
+      in.skipBytes(arc.bytesPerArc() * mid + 1);
+      final int midLabel = fst.readLabel(in);
+      final int cmp = midLabel - targetLabel;
+      if (cmp < 0) {
+        low = mid + 1;
+      } else if (cmp > 0) {
+        high = mid - 1;
+      } else {
+        return mid;
+      }
+    }
+    return -1 - low;
+  }
 }
diff --git a/lucene/core/src/test/org/apache/lucene/util/fst/TestUtil.java b/lucene/core/src/test/org/apache/lucene/util/fst/TestUtil.java
new file mode 100644
index 00000000000..5ec163e9cfe
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/util/fst/TestUtil.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.util.fst;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.IntsRefBuilder;
+import org.apache.lucene.util.LuceneTestCase;
+
+public class TestUtil extends LuceneTestCase {
+
+  public void testBinarySearch() throws Exception {
+    // Creates a node with 8 arcs spanning (z-A) = 57 chars that will be encoded as a sparse array (no gaps)
+    // requiring binary search
+    List<String> letters = Arrays.asList("A", "E", "J", "K", "L", "O", "T", "z");
+    FST<Object> fst = buildFST(letters, true);
+    FST.Arc<Object> arc = fst.getFirstArc(new FST.Arc<>());
+    arc = fst.readFirstTargetArc(arc, arc, fst.getBytesReader());
+    for (int i = 0; i < letters.size(); i++) {
+      assertEquals(i, Util.binarySearch(fst, arc, letters.get(i).charAt(0)));
+    }
+    // before the first
+    assertEquals(-1, Util.binarySearch(fst, arc, ' '));
+    // after the last
+    assertEquals(-1 - letters.size(), Util.binarySearch(fst, arc, '~'));
+    assertEquals(-2, Util.binarySearch(fst, arc, 'B'));
+    assertEquals(-2, Util.binarySearch(fst, arc, 'C'));
+    assertEquals(-7, Util.binarySearch(fst, arc, 'P'));
+  }
+
+  public void testReadCeilArcPackedArray() throws Exception {
+    List<String> letters = Arrays.asList("A", "E", "J", "K", "L", "O", "T", "z");
+    verifyReadCeilArc(letters, true);
+  }
+
+  public void testReadCeilArcArrayWithGaps() throws Exception {
+    List<String> letters = Arrays.asList("A", "E", "J", "K", "L", "O", "T");
+    verifyReadCeilArc(letters, true);
+  }
+
+  public void testReadCeilArcList() throws Exception {
+    List<String> letters = Arrays.asList("A", "E", "J", "K", "L", "O", "T", "z");
+    verifyReadCeilArc(letters, false);
+  }
+
+  private void verifyReadCeilArc(List<String> letters, boolean allowArrayArcs) throws Exception {
+    FST<Object> fst = buildFST(letters, allowArrayArcs);
+    FST.Arc<Object> first = fst.getFirstArc(new FST.Arc<>());
+    FST.Arc<Object> arc = new FST.Arc<>();
+    FST.BytesReader in = fst.getBytesReader();
+    for (String letter : letters) {
+      char c = letter.charAt(0);
+      arc = Util.readCeilArc(c, fst, first, arc, in);
+      assertNotNull(arc);
+      assertEquals(c, arc.label());
+    }
+    // before the first
+    assertEquals('A', Util.readCeilArc(' ', fst, first, arc, in).label());
+    // after the last
+    assertNull(Util.readCeilArc('~', fst, first, arc, in));
+    // in the middle
+    assertEquals('J', Util.readCeilArc('F', fst, first, arc, in).label());
+    // no following arcs
+    assertNull(Util.readCeilArc('Z', fst, arc, arc, in));
+  }
+
+  private FST<Object> buildFST(List<String> words, boolean allowArrayArcs) throws Exception {
+    final Outputs<Object> outputs = NoOutputs.getSingleton();
+    final Builder<Object> b = new Builder<>(FST.INPUT_TYPE.BYTE1, 0, 0, true, true, Integer.MAX_VALUE, outputs, allowArrayArcs, 15);
+
+    for (String word : words) {
+      b.add(Util.toIntsRef(new BytesRef(word), new IntsRefBuilder()), outputs.getNoOutput());
+    }
+    return b.finish();
+  }
+
+  private List<String> createRandomDictionary(int width, int depth) {
+    return createRandomDictionary(new ArrayList<>(), new StringBuilder(), width, depth);
+  }
+
+  private List<String> createRandomDictionary(List<String> dict, StringBuilder buf, int width, int depth) {
+    char c = (char) random().nextInt(128);
+    assert width < Character.MIN_SURROGATE / 8 - 128; // avoid surrogate chars
+    int len = buf.length();
+    for (int i = 0; i < width; i++) {
+      buf.append(c);
+      if (depth > 0) {
+        createRandomDictionary(dict, buf, width, depth - 1);
+      } else {
+        dict.add(buf.toString());
+      }
+      c += random().nextInt(8);
+      buf.setLength(len);
+    }
+    return dict;
+  }
+
+}