From 5eb575f8abb7e0d7abe79b227d581d2fe40fd0a8 Mon Sep 17 00:00:00 2001 From: gf2121 <52390227+gf2121@users.noreply.github.com> Date: Wed, 1 Dec 2021 22:31:05 +0800 Subject: [PATCH] LUCENE-10233: Store docIds as bitset to speed up addAll (#438) --- lucene/CHANGES.txt | 3 + .../document/LatLonPointDistanceQuery.java | 5 + .../lucene/document/RangeFieldQuery.java | 5 + .../apache/lucene/document/SpatialQuery.java | 5 + .../document/XYPointInGeometryQuery.java | 5 + .../org/apache/lucene/index/PointValues.java | 11 +++ .../apache/lucene/search/PointInSetQuery.java | 10 ++ .../apache/lucene/search/PointRangeQuery.java | 5 + .../lucene/util/DocBaseBitSetIterator.java | 92 +++++++++++++++++++ .../apache/lucene/util/DocIdSetBuilder.java | 12 +++ .../org/apache/lucene/util/FixedBitSet.java | 19 +++- .../apache/lucene/util/bkd/DocIdsWriter.java | 84 ++++++++++++++++- .../lucene/util/bkd/TestDocIdsWriter.java | 18 ++++ .../sandbox/search/MultiRangeQuery.java | 5 + .../PointInShapeIntersectVisitor.java | 7 ++ 15 files changed, 279 insertions(+), 7 deletions(-) create mode 100644 lucene/core/src/java/org/apache/lucene/util/DocBaseBitSetIterator.java diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 5024f6c9e44..819e70c8c08 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -73,6 +73,9 @@ Improvements Optimizations --------------------- +* LUCENE-10233: Store BKD leaves' doc IDs as bitset in some cases (typically for low cardinality fields + or sorted indices) to speed up addAll. (Guo Feng, Adrien Grand) + * LUCENE-10225: Improve IntroSelector with 3-ways partitioning. (Bruno Roustant, Adrien Grand) Bug Fixes diff --git a/lucene/core/src/java/org/apache/lucene/document/LatLonPointDistanceQuery.java b/lucene/core/src/java/org/apache/lucene/document/LatLonPointDistanceQuery.java index 06a2b9e7310..b1950e5b366 100644 --- a/lucene/core/src/java/org/apache/lucene/document/LatLonPointDistanceQuery.java +++ b/lucene/core/src/java/org/apache/lucene/document/LatLonPointDistanceQuery.java @@ -256,6 +256,11 @@ final class LatLonPointDistanceQuery extends Query { adder.add(docID); } + @Override + public void visit(DocIdSetIterator iterator) throws IOException { + adder.add(iterator); + } + @Override public void visit(int docID, byte[] packedValue) { if (matches(packedValue)) { diff --git a/lucene/core/src/java/org/apache/lucene/document/RangeFieldQuery.java b/lucene/core/src/java/org/apache/lucene/document/RangeFieldQuery.java index 0d0edbb7f4d..ab83969de6f 100644 --- a/lucene/core/src/java/org/apache/lucene/document/RangeFieldQuery.java +++ b/lucene/core/src/java/org/apache/lucene/document/RangeFieldQuery.java @@ -466,6 +466,11 @@ public abstract class RangeFieldQuery extends Query { adder.add(docID); } + @Override + public void visit(DocIdSetIterator iterator) throws IOException { + adder.add(iterator); + } + @Override public void visit(int docID, byte[] leaf) throws IOException { if (queryType.matches(ranges, leaf, numDims, bytesPerDim)) { diff --git a/lucene/core/src/java/org/apache/lucene/document/SpatialQuery.java b/lucene/core/src/java/org/apache/lucene/document/SpatialQuery.java index 1d6f98114bd..19a5d83bcb8 100644 --- a/lucene/core/src/java/org/apache/lucene/document/SpatialQuery.java +++ b/lucene/core/src/java/org/apache/lucene/document/SpatialQuery.java @@ -422,6 +422,11 @@ abstract class SpatialQuery extends Query { adder.add(docID); } + @Override + public void visit(DocIdSetIterator iterator) throws IOException { + adder.add(iterator); + } + @Override public void visit(int docID, byte[] t) { if (leafPredicate.test(t)) { diff --git a/lucene/core/src/java/org/apache/lucene/document/XYPointInGeometryQuery.java b/lucene/core/src/java/org/apache/lucene/document/XYPointInGeometryQuery.java index 3260845ba53..1533463a273 100644 --- a/lucene/core/src/java/org/apache/lucene/document/XYPointInGeometryQuery.java +++ b/lucene/core/src/java/org/apache/lucene/document/XYPointInGeometryQuery.java @@ -85,6 +85,11 @@ final class XYPointInGeometryQuery extends Query { adder.add(docID); } + @Override + public void visit(DocIdSetIterator iterator) throws IOException { + adder.add(iterator); + } + @Override public void visit(int docID, byte[] packedValue) { double x = XYEncodingUtils.decode(packedValue, 0); diff --git a/lucene/core/src/java/org/apache/lucene/index/PointValues.java b/lucene/core/src/java/org/apache/lucene/index/PointValues.java index 45416459786..64229d18936 100644 --- a/lucene/core/src/java/org/apache/lucene/index/PointValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/PointValues.java @@ -287,6 +287,17 @@ public abstract class PointValues { */ void visit(int docID) throws IOException; + /** + * Similar to {@link IntersectVisitor#visit(int)}, but a bulk visit and implements may have + * their optimizations. + */ + default void visit(DocIdSetIterator iterator) throws IOException { + int docID; + while ((docID = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { + visit(docID); + } + } + /** * Called for all documents in a leaf cell that crosses the query. The consumer should * scrutinize the packedValue to decide whether to accept it. In the 1D case, values are visited diff --git a/lucene/core/src/java/org/apache/lucene/search/PointInSetQuery.java b/lucene/core/src/java/org/apache/lucene/search/PointInSetQuery.java index e85b7f18677..a49072794dd 100644 --- a/lucene/core/src/java/org/apache/lucene/search/PointInSetQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/PointInSetQuery.java @@ -237,6 +237,11 @@ public abstract class PointInSetQuery extends Query implements Accountable { adder.add(docID); } + @Override + public void visit(DocIdSetIterator iterator) throws IOException { + adder.add(iterator); + } + @Override public void visit(int docID, byte[] packedValue) { if (matches(packedValue)) { @@ -336,6 +341,11 @@ public abstract class PointInSetQuery extends Query implements Accountable { adder.add(docID); } + @Override + public void visit(DocIdSetIterator iterator) throws IOException { + adder.add(iterator); + } + @Override public void visit(int docID, byte[] packedValue) { assert packedValue.length == pointBytes.length; diff --git a/lucene/core/src/java/org/apache/lucene/search/PointRangeQuery.java b/lucene/core/src/java/org/apache/lucene/search/PointRangeQuery.java index 6e69c748407..06b7836e4be 100644 --- a/lucene/core/src/java/org/apache/lucene/search/PointRangeQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/PointRangeQuery.java @@ -178,6 +178,11 @@ public abstract class PointRangeQuery extends Query { adder.add(docID); } + @Override + public void visit(DocIdSetIterator iterator) throws IOException { + adder.add(iterator); + } + @Override public void visit(int docID, byte[] packedValue) { if (matches(packedValue)) { diff --git a/lucene/core/src/java/org/apache/lucene/util/DocBaseBitSetIterator.java b/lucene/core/src/java/org/apache/lucene/util/DocBaseBitSetIterator.java new file mode 100644 index 00000000000..ba05f94c579 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/util/DocBaseBitSetIterator.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.util; + +import org.apache.lucene.search.DocIdSetIterator; + +/** + * A {@link DocIdSetIterator} like {@link BitSetIterator} but has a doc base in onder to avoid + * storing previous 0s. + */ +public class DocBaseBitSetIterator extends DocIdSetIterator { + + private final FixedBitSet bits; + private final int length; + private final long cost; + private final int docBase; + private int doc = -1; + + public DocBaseBitSetIterator(FixedBitSet bits, long cost, int docBase) { + if (cost < 0) { + throw new IllegalArgumentException("cost must be >= 0, got " + cost); + } + if ((docBase & 63) != 0) { + throw new IllegalArgumentException("docBase need to be a multiple of 64"); + } + this.bits = bits; + this.length = bits.length() + docBase; + this.cost = cost; + this.docBase = docBase; + } + + /** + * Get the {@link FixedBitSet}. A docId will exist in this {@link DocIdSetIterator} if the bitset + * contains the (docId - {@link #getDocBase}) + * + * @return the offset docId bitset + */ + public FixedBitSet getBitSet() { + return bits; + } + + @Override + public int docID() { + return doc; + } + + /** + * Get the docBase. It is guaranteed that docBase is a multiple of 64. + * + * @return the docBase + */ + public int getDocBase() { + return docBase; + } + + @Override + public int nextDoc() { + return advance(doc + 1); + } + + @Override + public int advance(int target) { + if (target >= length) { + return doc = NO_MORE_DOCS; + } + int next = bits.nextSetBit(Math.max(0, target - docBase)); + if (next == NO_MORE_DOCS) { + return doc = NO_MORE_DOCS; + } else { + return doc = next + docBase; + } + } + + @Override + public long cost() { + return cost; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/util/DocIdSetBuilder.java b/lucene/core/src/java/org/apache/lucene/util/DocIdSetBuilder.java index 0abf92eba14..67b3dde9f20 100644 --- a/lucene/core/src/java/org/apache/lucene/util/DocIdSetBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/DocIdSetBuilder.java @@ -43,6 +43,13 @@ public final class DocIdSetBuilder { */ public abstract static class BulkAdder { public abstract void add(int doc); + + public void add(DocIdSetIterator iterator) throws IOException { + int docID; + while ((docID = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { + add(docID); + } + } } private static class FixedBitSetAdder extends BulkAdder { @@ -56,6 +63,11 @@ public final class DocIdSetBuilder { public void add(int doc) { bitSet.set(doc); } + + @Override + public void add(DocIdSetIterator iterator) throws IOException { + bitSet.or(iterator); + } } private static class Buffer { diff --git a/lucene/core/src/java/org/apache/lucene/util/FixedBitSet.java b/lucene/core/src/java/org/apache/lucene/util/FixedBitSet.java index 1e914949202..d30a3147425 100644 --- a/lucene/core/src/java/org/apache/lucene/util/FixedBitSet.java +++ b/lucene/core/src/java/org/apache/lucene/util/FixedBitSet.java @@ -269,6 +269,10 @@ public final class FixedBitSet extends BitSet { checkUnpositioned(iter); final FixedBitSet bits = BitSetIterator.getFixedBitSetOrNull(iter); or(bits); + } else if (iter instanceof DocBaseBitSetIterator) { + checkUnpositioned(iter); + DocBaseBitSetIterator baseIter = (DocBaseBitSetIterator) iter; + or(baseIter.getDocBase() >> 6, baseIter.getBitSet()); } else { super.or(iter); } @@ -276,15 +280,20 @@ public final class FixedBitSet extends BitSet { /** this = this OR other */ public void or(FixedBitSet other) { - or(other.bits, other.numWords); + or(0, other.bits, other.numWords); } - private void or(final long[] otherArr, final int otherNumWords) { - assert otherNumWords <= numWords : "numWords=" + numWords + ", otherNumWords=" + otherNumWords; + private void or(final int otherOffsetWords, FixedBitSet other) { + or(otherOffsetWords, other.bits, other.numWords); + } + + private void or(final int otherOffsetWords, final long[] otherArr, final int otherNumWords) { + assert otherNumWords + otherOffsetWords <= numWords + : "numWords=" + numWords + ", otherNumWords=" + otherNumWords; + int pos = Math.min(numWords - otherOffsetWords, otherNumWords); final long[] thisArr = this.bits; - int pos = Math.min(numWords, otherNumWords); while (--pos >= 0) { - thisArr[pos] |= otherArr[pos]; + thisArr[pos + otherOffsetWords] |= otherArr[pos]; } } diff --git a/lucene/core/src/java/org/apache/lucene/util/bkd/DocIdsWriter.java b/lucene/core/src/java/org/apache/lucene/util/bkd/DocIdsWriter.java index 6be3c43a1fa..def705b708a 100644 --- a/lucene/core/src/java/org/apache/lucene/util/bkd/DocIdsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/util/bkd/DocIdsWriter.java @@ -18,8 +18,11 @@ package org.apache.lucene.util.bkd; import java.io.IOException; import org.apache.lucene.index.PointValues.IntersectVisitor; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.store.DataOutput; import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.DocBaseBitSetIterator; +import org.apache.lucene.util.FixedBitSet; class DocIdsWriter { @@ -29,12 +32,26 @@ class DocIdsWriter { // docs can be sorted either when all docs in a block have the same value // or when a segment is sorted boolean sorted = true; + boolean strictlySorted = true; for (int i = 1; i < count; ++i) { - if (docIds[start + i - 1] > docIds[start + i]) { - sorted = false; + int last = docIds[start + i - 1]; + int current = docIds[start + i]; + if (last > current) { + sorted = strictlySorted = false; break; + } else if (last == current) { + strictlySorted = false; } } + + if (strictlySorted && (docIds[start + count - 1] - docIds[start] + 1) <= (count << 4)) { + // Only trigger this optimization when max - min + 1 <= 16 * count in order to avoid expanding + // too much storage. + // A field with lower cardinality will have higher probability to trigger this optimization. + out.writeByte((byte) -1); + writeIdsAsBitSet(docIds, start, count, out); + return; + } if (sorted) { out.writeByte((byte) 0); int previous = 0; @@ -85,10 +102,46 @@ class DocIdsWriter { } } + private static void writeIdsAsBitSet(int[] docIds, int start, int count, DataOutput out) + throws IOException { + int min = docIds[start]; + int max = docIds[start + count - 1]; + + final int offsetWords = min >> 6; + final int offsetBits = offsetWords << 6; + final int totalWordCount = FixedBitSet.bits2words(max - offsetBits + 1); + long currentWord = 0; + int currentWordIndex = 0; + + out.writeVInt(offsetWords); + out.writeVInt(totalWordCount); + // build bit set streaming + for (int i = 0; i < count; i++) { + final int index = docIds[start + i] - offsetBits; + final int nextWordIndex = index >> 6; + assert currentWordIndex <= nextWordIndex; + if (currentWordIndex < nextWordIndex) { + out.writeLong(currentWord); + currentWord = 0L; + currentWordIndex++; + while (currentWordIndex < nextWordIndex) { + currentWordIndex++; + out.writeLong(0L); + } + } + currentWord |= 1L << index; + } + out.writeLong(currentWord); + assert currentWordIndex + 1 == totalWordCount; + } + /** Read {@code count} integers into {@code docIDs}. */ static void readInts(IndexInput in, int count, int[] docIDs) throws IOException { final int bpv = in.readByte(); switch (bpv) { + case -1: + readBitSet(in, count, docIDs); + break; case 0: readDeltaVInts(in, count, docIDs); break; @@ -103,6 +156,24 @@ class DocIdsWriter { } } + private static DocIdSetIterator readBitSetIterator(IndexInput in, int count) throws IOException { + int offsetWords = in.readVInt(); + int longLen = in.readVInt(); + long[] bits = new long[longLen]; + in.readLongs(bits, 0, longLen); + FixedBitSet bitSet = new FixedBitSet(bits, longLen << 6); + return new DocBaseBitSetIterator(bitSet, count, offsetWords << 6); + } + + private static void readBitSet(IndexInput in, int count, int[] docIDs) throws IOException { + DocIdSetIterator iterator = readBitSetIterator(in, count); + int docId, pos = 0; + while ((docId = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { + docIDs[pos++] = docId; + } + assert pos == count : "pos: " + pos + "count: " + count; + } + private static void readDeltaVInts(IndexInput in, int count, int[] docIDs) throws IOException { int doc = 0; for (int i = 0; i < count; i++) { @@ -144,6 +215,9 @@ class DocIdsWriter { static void readInts(IndexInput in, int count, IntersectVisitor visitor) throws IOException { final int bpv = in.readByte(); switch (bpv) { + case -1: + readBitSet(in, count, visitor); + break; case 0: readDeltaVInts(in, count, visitor); break; @@ -194,4 +268,10 @@ class DocIdsWriter { visitor.visit((Short.toUnsignedInt(in.readShort()) << 8) | Byte.toUnsignedInt(in.readByte())); } } + + private static void readBitSet(IndexInput in, int count, IntersectVisitor visitor) + throws IOException { + DocIdSetIterator bitSetIterator = readBitSetIterator(in, count); + visitor.visit(bitSetIterator); + } } diff --git a/lucene/core/src/test/org/apache/lucene/util/bkd/TestDocIdsWriter.java b/lucene/core/src/test/org/apache/lucene/util/bkd/TestDocIdsWriter.java index e70de931034..247329461c3 100644 --- a/lucene/core/src/test/org/apache/lucene/util/bkd/TestDocIdsWriter.java +++ b/lucene/core/src/test/org/apache/lucene/util/bkd/TestDocIdsWriter.java @@ -18,6 +18,8 @@ package org.apache.lucene.util.bkd; import java.io.IOException; import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; import org.apache.lucene.index.PointValues.IntersectVisitor; import org.apache.lucene.index.PointValues.Relation; import org.apache.lucene.store.Directory; @@ -58,6 +60,22 @@ public class TestDocIdsWriter extends LuceneTestCase { } } + public void testBitSet() throws Exception { + int numIters = atLeast(100); + try (Directory dir = newDirectory()) { + for (int iter = 0; iter < numIters; ++iter) { + int size = random().nextInt(5000); + Set set = new HashSet<>(size); + int small = random().nextInt(1000); + while (set.size() < size) { + set.add(small + random().nextInt(size * 16)); + } + int[] docIDs = set.stream().mapToInt(t -> t).sorted().toArray(); + test(dir, docIDs); + } + } + } + private void test(Directory dir, int[] ints) throws Exception { final long len; try (IndexOutput out = dir.createOutput("tmp", IOContext.DEFAULT)) { diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/MultiRangeQuery.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/MultiRangeQuery.java index f0583bd2761..a5ba6c225bc 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/MultiRangeQuery.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/MultiRangeQuery.java @@ -171,6 +171,11 @@ public abstract class MultiRangeQuery extends Query { adder.add(docID); } + @Override + public void visit(DocIdSetIterator iterator) throws IOException { + adder.add(iterator); + } + @Override public void visit(int docID, byte[] packedValue) { // If a single OR clause has the value in range, the entire query accepts the value diff --git a/lucene/spatial3d/src/java/org/apache/lucene/spatial3d/PointInShapeIntersectVisitor.java b/lucene/spatial3d/src/java/org/apache/lucene/spatial3d/PointInShapeIntersectVisitor.java index f33ee4fabed..8883fef2240 100644 --- a/lucene/spatial3d/src/java/org/apache/lucene/spatial3d/PointInShapeIntersectVisitor.java +++ b/lucene/spatial3d/src/java/org/apache/lucene/spatial3d/PointInShapeIntersectVisitor.java @@ -17,8 +17,10 @@ package org.apache.lucene.spatial3d; +import java.io.IOException; import org.apache.lucene.index.PointValues.IntersectVisitor; import org.apache.lucene.index.PointValues.Relation; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.spatial3d.geom.GeoArea; import org.apache.lucene.spatial3d.geom.GeoAreaFactory; import org.apache.lucene.spatial3d.geom.GeoShape; @@ -60,6 +62,11 @@ class PointInShapeIntersectVisitor implements IntersectVisitor { adder.add(docID); } + @Override + public void visit(DocIdSetIterator iterator) throws IOException { + adder.add(iterator); + } + @Override public void visit(int docID, byte[] packedValue) { assert packedValue.length == 12;