From aa81ba8642a57181a4eaa017b52d0d3c3462544b Mon Sep 17 00:00:00 2001
From: Adrien Grand <jpountz@gmail.com>
Date: Fri, 29 Apr 2016 10:35:27 +0200
Subject: [PATCH] LUCENE-7264: Fewer conditionals in DocIdSetBuilder.

---
 lucene/CHANGES.txt                            |  4 +-
 .../simpletext/SimpleTextBKDReader.java       |  2 +
 .../apache/lucene/search/PointInSetQuery.java | 14 ++--
 .../apache/lucene/search/PointRangeQuery.java |  8 +-
 .../apache/lucene/util/DocIdSetBuilder.java   | 73 +++++++++++--------
 .../lucene/search/TestReqExclBulkScorer.java  |  6 +-
 .../lucene/util/TestDocIdSetBuilder.java      | 13 ++--
 ...GeoPointTermQueryConstantScoreWrapper.java |  8 +-
 .../PointInShapeIntersectVisitor.java         |  7 +-
 .../asserting/AssertingPointsFormat.java      |  5 ++
 10 files changed, 86 insertions(+), 54 deletions(-)

diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 4739a5fba80..a23a54779ca 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -78,8 +78,8 @@ Optimizations
 * LUCENE-7237: LRUQueryCache now prefers returning an uncached Scorer than
   waiting on a lock. (Adrien Grand)
 
-* LUCENE-7261: Speed up LSBRadixSorter (which is used by TermsQuery, multi-term
-  queries and point queries). (Adrien Grand)
+* LUCENE-7261, LUCENE-7264: Speed up DocIdSetBuilder (which is used by
+  TermsQuery, multi-term queries and point queries). (Adrien Grand)
 
 Bug Fixes
 
diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextBKDReader.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextBKDReader.java
index 6e2e1ac815a..35e94489472 100644
--- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextBKDReader.java
+++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextBKDReader.java
@@ -44,6 +44,7 @@ class SimpleTextBKDReader extends BKDReader {
     in.seek(blockFP);
     readLine(in, scratch);
     int count = parseInt(scratch, BLOCK_COUNT);
+    visitor.grow(count);
     for(int i=0;i<count;i++) {
       readLine(in, scratch);
       visitor.visit(parseInt(scratch, BLOCK_DOC_ID));
@@ -65,6 +66,7 @@ class SimpleTextBKDReader extends BKDReader {
 
   @Override
   protected void visitDocValues(int[] commonPrefixLengths, byte[] scratchPackedValue, IndexInput in, int[] docIDs, int count, IntersectVisitor visitor) throws IOException {
+    visitor.grow(count);
     // NOTE: we don't do prefix coding, so we ignore commonPrefixLengths
     assert scratchPackedValue.length == packedBytesLength;
     BytesRefBuilder scratch = new BytesRefBuilder();
diff --git a/lucene/core/src/java/org/apache/lucene/search/PointInSetQuery.java b/lucene/core/src/java/org/apache/lucene/search/PointInSetQuery.java
index 3314e29d49f..9be794cccc7 100644
--- a/lucene/core/src/java/org/apache/lucene/search/PointInSetQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/PointInSetQuery.java
@@ -163,6 +163,7 @@ public abstract class PointInSetQuery extends Query {
     private BytesRef nextQueryPoint;
     private final BytesRef scratch = new BytesRef();
     private final PrefixCodedTerms sortedPackedPoints;
+    private DocIdSetBuilder.BulkAdder adder;
 
     public MergePointVisitor(PrefixCodedTerms sortedPackedPoints, DocIdSetBuilder result) throws IOException {
       this.result = result;
@@ -174,12 +175,12 @@ public abstract class PointInSetQuery extends Query {
 
     @Override
     public void grow(int count) {
-      result.grow(count);
+      adder = result.grow(count);
     }
 
     @Override
     public void visit(int docID) {
-      result.add(docID);
+      adder.add(docID);
     }
 
     @Override
@@ -189,7 +190,7 @@ public abstract class PointInSetQuery extends Query {
         int cmp = nextQueryPoint.compareTo(scratch);
         if (cmp == 0) {
           // Query point equals index point, so collect and return
-          result.add(docID);
+          adder.add(docID);
           break;
         } else if (cmp < 0) {
           // Query point is before index point, so we move to next query point
@@ -237,6 +238,7 @@ public abstract class PointInSetQuery extends Query {
 
     private final DocIdSetBuilder result;
     private final byte[] pointBytes;
+    private DocIdSetBuilder.BulkAdder adder;
 
     public SinglePointVisitor(DocIdSetBuilder result) {
       this.result = result;
@@ -251,12 +253,12 @@ public abstract class PointInSetQuery extends Query {
 
     @Override
     public void grow(int count) {
-      result.grow(count);
+      adder = result.grow(count);
     }
 
     @Override
     public void visit(int docID) {
-      result.add(docID);
+      adder.add(docID);
     }
 
     @Override
@@ -264,7 +266,7 @@ public abstract class PointInSetQuery extends Query {
       assert packedValue.length == pointBytes.length;
       if (Arrays.equals(packedValue, pointBytes)) {
         // The point for this doc matches the point we are querying on
-        result.add(docID);
+        adder.add(docID);
       }
     }
 
diff --git a/lucene/core/src/java/org/apache/lucene/search/PointRangeQuery.java b/lucene/core/src/java/org/apache/lucene/search/PointRangeQuery.java
index 4991096e7cf..cae2192c21b 100644
--- a/lucene/core/src/java/org/apache/lucene/search/PointRangeQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/PointRangeQuery.java
@@ -111,14 +111,16 @@ public abstract class PointRangeQuery extends Query {
         values.intersect(field,
             new IntersectVisitor() {
 
+              DocIdSetBuilder.BulkAdder adder;
+
               @Override
               public void grow(int count) {
-                result.grow(count);
+                adder = result.grow(count);
               }
 
               @Override
               public void visit(int docID) {
-                result.add(docID);
+                adder.add(docID);
               }
 
               @Override
@@ -136,7 +138,7 @@ public abstract class PointRangeQuery extends Query {
                 }
 
                 // Doc is in-bounds
-                result.add(docID);
+                adder.add(docID);
               }
 
               @Override
diff --git a/lucene/core/src/java/org/apache/lucene/util/DocIdSetBuilder.java b/lucene/core/src/java/org/apache/lucene/util/DocIdSetBuilder.java
index f35e584aba6..b85a8f29fbe 100644
--- a/lucene/core/src/java/org/apache/lucene/util/DocIdSetBuilder.java
+++ b/lucene/core/src/java/org/apache/lucene/util/DocIdSetBuilder.java
@@ -17,6 +17,7 @@
 package org.apache.lucene.util;
 
 import java.io.IOException;
+import java.util.Arrays;
 
 import org.apache.lucene.search.DocIdSet;
 import org.apache.lucene.search.DocIdSetIterator;
@@ -26,17 +27,50 @@ import org.apache.lucene.util.packed.PackedInts;
  * A builder of {@link DocIdSet}s.  At first it uses a sparse structure to gather
  * documents, and then upgrades to a non-sparse bit set once enough hits match.
  *
+ * To add documents, you first need to call {@link #grow} in order to reserve
+ * space, and then call {@link BulkAdder#add(int)} on the returned
+ * {@link BulkAdder}.
+ *
  * @lucene.internal
  */
 public final class DocIdSetBuilder {
 
+  /** Utility class to efficiently add many docs in one go.
+   *  @see DocIdSetBuilder#grow */
+  public static abstract class BulkAdder {
+    public abstract void add(int doc);
+  }
+
+  private static class FixedBitSetAdder extends BulkAdder {
+    final FixedBitSet bitSet;
+
+    FixedBitSetAdder(FixedBitSet bitSet) {
+      this.bitSet = bitSet;
+    }
+
+    @Override
+    public void add(int doc) {
+      bitSet.set(doc);
+    }
+  }
+
+  private class BufferAdder extends BulkAdder {
+
+    @Override
+    public void add(int doc) {
+      buffer[bufferSize++] = doc;
+    }
+
+  }
+
   private final int maxDoc;
   private final int threshold;
 
   private int[] buffer;
   private int bufferSize;
 
-  private BitSet bitSet;
+  private FixedBitSet bitSet;
+  private BulkAdder adder = new BufferAdder();
 
   /**
    * Create a builder that can contain doc IDs between {@code 0} and {@code maxDoc}.
@@ -62,6 +96,7 @@ public final class DocIdSetBuilder {
     }
     this.buffer = null;
     this.bufferSize = 0;
+    this.adder = new FixedBitSetAdder(bitSet);
   }
 
   /** Grows the buffer to at least minSize, but never larger than threshold. */
@@ -69,9 +104,7 @@ public final class DocIdSetBuilder {
     assert minSize < threshold;
     if (buffer.length < minSize) {
       int nextSize = Math.min(threshold, ArrayUtil.oversize(minSize, Integer.BYTES));
-      int[] newBuffer = new int[nextSize];
-      System.arraycopy(buffer, 0, newBuffer, 0, buffer.length);
-      buffer = newBuffer;
+      buffer = Arrays.copyOf(buffer, nextSize);
     }
   }
 
@@ -86,7 +119,7 @@ public final class DocIdSetBuilder {
     if (bitSet != null) {
       bitSet.or(iter);
     } else {
-      while (true) {  
+      while (true) {
         assert buffer.length <= threshold;
         final int end = buffer.length;
         for (int i = bufferSize; i < end; ++i) {
@@ -114,39 +147,19 @@ public final class DocIdSetBuilder {
   }
 
   /**
-   * Reserve space so that this builder can hold {@code numDocs} MORE documents.
+   * Reserve space and return a {@link BulkAdder} object that can be used to
+   * add up to {@code numDocs} documents.
    */
-  public void grow(int numDocs) {
+  public BulkAdder grow(int numDocs) {
     if (bitSet == null) {
-      final long newLength = bufferSize + numDocs;
+      final long newLength = (long) bufferSize + numDocs;
       if (newLength < threshold) {
         growBuffer((int) newLength);
       } else {
         upgradeToBitSet();
       }
     }
-  }
-
-  /**
-   * Add a document to this builder.
-   * NOTE: doc IDs do not need to be provided in order.
-   * NOTE: if you plan on adding several docs at once, look into using
-   * {@link #grow(int)} to reserve space.
-   */
-  public void add(int doc) {
-    if (bitSet != null) {
-      bitSet.set(doc);
-    } else {
-      if (bufferSize + 1 > buffer.length) {
-        if (bufferSize + 1 >= threshold) {
-          upgradeToBitSet();
-          bitSet.set(doc);
-          return;
-        }
-        growBuffer(bufferSize+1);
-      }
-      buffer[bufferSize++] = doc;
-    }
+    return adder;
   }
 
   private static int dedup(int[] arr, int length) {
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestReqExclBulkScorer.java b/lucene/core/src/test/org/apache/lucene/search/TestReqExclBulkScorer.java
index 20917dc13ca..bbc4740e454 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestReqExclBulkScorer.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestReqExclBulkScorer.java
@@ -40,11 +40,13 @@ public class TestReqExclBulkScorer extends LuceneTestCase {
     DocIdSetBuilder exclBuilder = new DocIdSetBuilder(maxDoc);
     final int numIncludedDocs = TestUtil.nextInt(random(), 1, maxDoc);
     final int numExcludedDocs = TestUtil.nextInt(random(), 1, maxDoc);
+    DocIdSetBuilder.BulkAdder reqAdder = reqBuilder.grow(numIncludedDocs);
     for (int i = 0; i < numIncludedDocs; ++i) {
-      reqBuilder.add(random().nextInt(maxDoc));
+      reqAdder.add(random().nextInt(maxDoc));
     }
+    DocIdSetBuilder.BulkAdder exclAdder = exclBuilder.grow(numIncludedDocs);
     for (int i = 0; i < numExcludedDocs; ++i) {
-      exclBuilder.add(random().nextInt(maxDoc));
+      exclAdder.add(random().nextInt(maxDoc));
     }
 
     final DocIdSet req = reqBuilder.build();
diff --git a/lucene/core/src/test/org/apache/lucene/util/TestDocIdSetBuilder.java b/lucene/core/src/test/org/apache/lucene/util/TestDocIdSetBuilder.java
index 97afe8b1854..8814be1f359 100644
--- a/lucene/core/src/test/org/apache/lucene/util/TestDocIdSetBuilder.java
+++ b/lucene/core/src/test/org/apache/lucene/util/TestDocIdSetBuilder.java
@@ -122,11 +122,14 @@ public class TestDocIdSetBuilder extends LuceneTestCase {
       DocIdSetBuilder builder = new DocIdSetBuilder(maxDoc);
       for (j = 0; j < array.length; ) {
         final int l = TestUtil.nextInt(random(), 1, array.length - j);
-        if (rarely()) {
-          builder.grow(l);
-        }
-        for (int k = 0; k < l; ++k) {
-          builder.add(array[j++]);
+        DocIdSetBuilder.BulkAdder adder = null;
+        for (int k = 0, budget = 0; k < l; ++k) {
+          if (budget == 0 || rarely()) {
+            budget = TestUtil.nextInt(random(), 1, l - k + 5);
+            adder = builder.grow(budget);
+          }
+          adder.add(array[j++]);
+          budget--;
         }
       }
 
diff --git a/lucene/spatial/src/java/org/apache/lucene/spatial/geopoint/search/GeoPointTermQueryConstantScoreWrapper.java b/lucene/spatial/src/java/org/apache/lucene/spatial/geopoint/search/GeoPointTermQueryConstantScoreWrapper.java
index 96e0bd961cb..fb467e6317d 100644
--- a/lucene/spatial/src/java/org/apache/lucene/spatial/geopoint/search/GeoPointTermQueryConstantScoreWrapper.java
+++ b/lucene/spatial/src/java/org/apache/lucene/spatial/geopoint/search/GeoPointTermQueryConstantScoreWrapper.java
@@ -110,9 +110,11 @@ final class GeoPointTermQueryConstantScoreWrapper <Q extends GeoPointMultiTermQu
           if (termsEnum.boundaryTerm()) {
             builder.add(docs);
           } else {
-            int docId;
-            while ((docId = docs.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
-              builder.add(docId);
+            int numDocs = termsEnum.docFreq();
+            DocIdSetBuilder.BulkAdder adder = builder.grow(numDocs);
+            for (int i = 0; i < numDocs; ++i) {
+              int docId = docs.nextDoc();
+              adder.add(docId);
               preApproved.set(docId);
             }
           }
diff --git a/lucene/spatial3d/src/java/org/apache/lucene/spatial3d/PointInShapeIntersectVisitor.java b/lucene/spatial3d/src/java/org/apache/lucene/spatial3d/PointInShapeIntersectVisitor.java
index cf94c35b2c9..77551090e51 100644
--- a/lucene/spatial3d/src/java/org/apache/lucene/spatial3d/PointInShapeIntersectVisitor.java
+++ b/lucene/spatial3d/src/java/org/apache/lucene/spatial3d/PointInShapeIntersectVisitor.java
@@ -31,6 +31,7 @@ class PointInShapeIntersectVisitor implements IntersectVisitor {
   private final DocIdSetBuilder hits;
   private final GeoShape shape;
   private final XYZBounds shapeBounds;
+  private DocIdSetBuilder.BulkAdder adder;
   
   public PointInShapeIntersectVisitor(DocIdSetBuilder hits, GeoShape shape, XYZBounds shapeBounds) {
     this.hits = hits;
@@ -40,12 +41,12 @@ class PointInShapeIntersectVisitor implements IntersectVisitor {
 
   @Override
   public void grow(int count) {
-    hits.grow(count);
+    adder = hits.grow(count);
   }
 
   @Override
   public void visit(int docID) {
-    hits.add(docID);
+    adder.add(docID);
   }
 
   @Override
@@ -58,7 +59,7 @@ class PointInShapeIntersectVisitor implements IntersectVisitor {
       y >= shapeBounds.getMinimumY() && y <= shapeBounds.getMaximumY() &&
       z >= shapeBounds.getMinimumZ() && z <= shapeBounds.getMaximumZ()) {
       if (shape.isWithin(x, y, z)) {
-        hits.add(docID);
+        adder.add(docID);
       }
     }
   }
diff --git a/lucene/test-framework/src/java/org/apache/lucene/codecs/asserting/AssertingPointsFormat.java b/lucene/test-framework/src/java/org/apache/lucene/codecs/asserting/AssertingPointsFormat.java
index 061d5b60a5c..b6729135435 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/codecs/asserting/AssertingPointsFormat.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/codecs/asserting/AssertingPointsFormat.java
@@ -77,6 +77,7 @@ public final class AssertingPointsFormat extends PointsFormat {
     final byte[] lastMaxPackedValue;
     private Relation lastCompareResult;
     private int lastDocID = -1;
+    private int docBudget;
 
     public AssertingIntersectVisitor(int numDims, int bytesPerDim, IntersectVisitor in) {
       this.in = in;
@@ -93,6 +94,8 @@ public final class AssertingPointsFormat extends PointsFormat {
 
     @Override
     public void visit(int docID) throws IOException {
+      assert --docBudget >= 0 : "called add() more times than the last call to grow() reserved";
+
       // This method, not filtering each hit, should only be invoked when the cell is inside the query shape:
       assert lastCompareResult == Relation.CELL_INSIDE_QUERY;
       in.visit(docID);
@@ -100,6 +103,7 @@ public final class AssertingPointsFormat extends PointsFormat {
 
     @Override
     public void visit(int docID, byte[] packedValue) throws IOException {
+      assert --docBudget >= 0 : "called add() more times than the last call to grow() reserved";
 
       // This method, to filter each doc's value, should only be invoked when the cell crosses the query shape:
       assert lastCompareResult == PointValues.Relation.CELL_CROSSES_QUERY;
@@ -130,6 +134,7 @@ public final class AssertingPointsFormat extends PointsFormat {
     @Override
     public void grow(int count) {
       in.grow(count);
+      docBudget = count;
     }
 
     @Override