LUCENE-10233: Store docIds as bitset to speed up addAll (#438)

This commit is contained in:
gf2121 2021-12-01 22:31:05 +08:00 committed by GitHub
parent a7ebf6618c
commit 5eb575f8ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 279 additions and 7 deletions

View File

@ -73,6 +73,9 @@ Improvements
Optimizations
---------------------
* LUCENE-10233: Store BKD leaves' doc IDs as bitset in some cases (typically for low cardinality fields
or sorted indices) to speed up addAll. (Guo Feng, Adrien Grand)
* LUCENE-10225: Improve IntroSelector with 3-ways partitioning. (Bruno Roustant, Adrien Grand)
Bug Fixes

View File

@ -256,6 +256,11 @@ final class LatLonPointDistanceQuery extends Query {
adder.add(docID);
}
@Override
public void visit(DocIdSetIterator iterator) throws IOException {
// Bulk variant of visit(int docID): forwards the whole iterator to the adder in a single call.
adder.add(iterator);
}
@Override
public void visit(int docID, byte[] packedValue) {
if (matches(packedValue)) {

View File

@ -466,6 +466,11 @@ public abstract class RangeFieldQuery extends Query {
adder.add(docID);
}
@Override
public void visit(DocIdSetIterator iterator) throws IOException {
// Bulk variant of visit(int docID): forwards the whole iterator to the adder in a single call.
adder.add(iterator);
}
@Override
public void visit(int docID, byte[] leaf) throws IOException {
if (queryType.matches(ranges, leaf, numDims, bytesPerDim)) {

View File

@ -422,6 +422,11 @@ abstract class SpatialQuery extends Query {
adder.add(docID);
}
@Override
public void visit(DocIdSetIterator iterator) throws IOException {
// Bulk variant of visit(int docID): forwards the whole iterator to the adder in a single call.
adder.add(iterator);
}
@Override
public void visit(int docID, byte[] t) {
if (leafPredicate.test(t)) {

View File

@ -85,6 +85,11 @@ final class XYPointInGeometryQuery extends Query {
adder.add(docID);
}
@Override
public void visit(DocIdSetIterator iterator) throws IOException {
// Bulk variant of visit(int docID): forwards the whole iterator to the adder in a single call.
adder.add(iterator);
}
@Override
public void visit(int docID, byte[] packedValue) {
double x = XYEncodingUtils.decode(packedValue, 0);

View File

@ -287,6 +287,17 @@ public abstract class PointValues {
*/
void visit(int docID) throws IOException;
/**
 * Bulk counterpart of {@link IntersectVisitor#visit(int)}. This default drains the iterator and
 * visits every document one at a time; implementations may override it with a faster bulk path.
 *
 * @param iterator the documents to visit; it is consumed until exhausted
 * @throws IOException if the iterator or a single-document visit fails
 */
default void visit(DocIdSetIterator iterator) throws IOException {
  for (int docID = iterator.nextDoc();
      docID != DocIdSetIterator.NO_MORE_DOCS;
      docID = iterator.nextDoc()) {
    visit(docID);
  }
}
/**
* Called for all documents in a leaf cell that crosses the query. The consumer should
* scrutinize the packedValue to decide whether to accept it. In the 1D case, values are visited

View File

@ -237,6 +237,11 @@ public abstract class PointInSetQuery extends Query implements Accountable {
adder.add(docID);
}
@Override
public void visit(DocIdSetIterator iterator) throws IOException {
// Bulk variant of visit(int docID): forwards the whole iterator to the adder in a single call.
adder.add(iterator);
}
@Override
public void visit(int docID, byte[] packedValue) {
if (matches(packedValue)) {
@ -336,6 +341,11 @@ public abstract class PointInSetQuery extends Query implements Accountable {
adder.add(docID);
}
@Override
public void visit(DocIdSetIterator iterator) throws IOException {
// Bulk variant of visit(int docID): forwards the whole iterator to the adder in a single call.
adder.add(iterator);
}
@Override
public void visit(int docID, byte[] packedValue) {
assert packedValue.length == pointBytes.length;

View File

@ -178,6 +178,11 @@ public abstract class PointRangeQuery extends Query {
adder.add(docID);
}
@Override
public void visit(DocIdSetIterator iterator) throws IOException {
// Bulk variant of visit(int docID): forwards the whole iterator to the adder in a single call.
adder.add(iterator);
}
@Override
public void visit(int docID, byte[] packedValue) {
if (matches(packedValue)) {

View File

@ -0,0 +1,92 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util;
import org.apache.lucene.search.DocIdSetIterator;
/**
 * A {@link DocIdSetIterator} like {@link BitSetIterator}, except that bit {@code i} of the wrapped
 * bitset stands for document {@code docBase + i}. The doc base makes it unnecessary to store the
 * leading run of zero bits.
 */
public class DocBaseBitSetIterator extends DocIdSetIterator {
  private final FixedBitSet bits;
  /** Exclusive upper bound of the doc IDs this iterator can return. */
  private final int length;

  private final long cost;
  private final int docBase;
  /** Current position; -1 until the iterator is first advanced. */
  private int doc = -1;

  /**
   * Creates an iterator over {@code bits} shifted by {@code docBase}.
   *
   * @param bits the offset bitset; bit {@code i} represents document {@code docBase + i}
   * @param cost the value reported by {@link #cost()}
   * @param docBase the doc ID that bit 0 maps to
   * @throws IllegalArgumentException if {@code cost} is negative or {@code docBase} is not a
   *     multiple of 64
   */
  public DocBaseBitSetIterator(FixedBitSet bits, long cost, int docBase) {
    if (cost < 0) {
      throw new IllegalArgumentException("cost must be >= 0, got " + cost);
    }
    final boolean wordAligned = (docBase & 63) == 0;
    if (!wordAligned) {
      throw new IllegalArgumentException("docBase need to be a multiple of 64");
    }
    this.bits = bits;
    this.length = bits.length() + docBase;
    this.cost = cost;
    this.docBase = docBase;
  }

  /**
   * Returns the wrapped {@link FixedBitSet}. A docId exists in this {@link DocIdSetIterator} if
   * the bitset contains the bit (docId - {@link #getDocBase}).
   *
   * @return the offset docId bitset
   */
  public FixedBitSet getBitSet() {
    return bits;
  }

  /**
   * Returns the docBase. It is guaranteed that docBase is a multiple of 64.
   *
   * @return the docBase
   */
  public int getDocBase() {
    return docBase;
  }

  @Override
  public int docID() {
    return doc;
  }

  @Override
  public long cost() {
    return cost;
  }

  @Override
  public int nextDoc() {
    return advance(doc + 1);
  }

  @Override
  public int advance(int target) {
    int result = NO_MORE_DOCS;
    if (target < length) {
      // Targets below docBase start the search at the first bit of the bitset.
      final int hit = bits.nextSetBit(Math.max(0, target - docBase));
      if (hit != NO_MORE_DOCS) {
        result = hit + docBase;
      }
    }
    doc = result;
    return result;
  }
}

View File

@ -43,6 +43,13 @@ public final class DocIdSetBuilder {
*/
public abstract static class BulkAdder {
  public abstract void add(int doc);

  /**
   * Adds every document of {@code iterator}. This base implementation drains the iterator and
   * calls {@link #add(int)} for each document; subclasses may override it with a bulk operation.
   *
   * @param iterator the documents to add; it is consumed until exhausted
   * @throws IOException if the iterator fails
   */
  public void add(DocIdSetIterator iterator) throws IOException {
    for (int docID = iterator.nextDoc();
        docID != DocIdSetIterator.NO_MORE_DOCS;
        docID = iterator.nextDoc()) {
      add(docID);
    }
  }
}
private static class FixedBitSetAdder extends BulkAdder {
@ -56,6 +63,11 @@ public final class DocIdSetBuilder {
public void add(int doc) {
bitSet.set(doc);
}
@Override
public void add(DocIdSetIterator iterator) throws IOException {
// Bulk path: delegate to bitSet.or(iterator) rather than setting the documents one at a time.
bitSet.or(iterator);
}
}
private static class Buffer {

View File

@ -269,6 +269,10 @@ public final class FixedBitSet extends BitSet {
checkUnpositioned(iter);
final FixedBitSet bits = BitSetIterator.getFixedBitSetOrNull(iter);
or(bits);
} else if (iter instanceof DocBaseBitSetIterator) {
checkUnpositioned(iter);
DocBaseBitSetIterator baseIter = (DocBaseBitSetIterator) iter;
or(baseIter.getDocBase() >> 6, baseIter.getBitSet());
} else {
super.or(iter);
}
@ -276,15 +280,20 @@ public final class FixedBitSet extends BitSet {
/** this = this OR other */
public void or(FixedBitSet other) {
or(other.bits, other.numWords);
or(0, other.bits, other.numWords);
}
private void or(final long[] otherArr, final int otherNumWords) {
assert otherNumWords <= numWords : "numWords=" + numWords + ", otherNumWords=" + otherNumWords;
private void or(final int otherOffsetWords, FixedBitSet other) {
or(otherOffsetWords, other.bits, other.numWords);
}
private void or(final int otherOffsetWords, final long[] otherArr, final int otherNumWords) {
assert otherNumWords + otherOffsetWords <= numWords
: "numWords=" + numWords + ", otherNumWords=" + otherNumWords;
int pos = Math.min(numWords - otherOffsetWords, otherNumWords);
final long[] thisArr = this.bits;
int pos = Math.min(numWords, otherNumWords);
while (--pos >= 0) {
thisArr[pos] |= otherArr[pos];
thisArr[pos + otherOffsetWords] |= otherArr[pos];
}
}

View File

@ -18,8 +18,11 @@ package org.apache.lucene.util.bkd;
import java.io.IOException;
import org.apache.lucene.index.PointValues.IntersectVisitor;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.store.DataOutput;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.DocBaseBitSetIterator;
import org.apache.lucene.util.FixedBitSet;
class DocIdsWriter {
@ -29,12 +32,26 @@ class DocIdsWriter {
// docs can be sorted either when all docs in a block have the same value
// or when a segment is sorted
boolean sorted = true;
boolean strictlySorted = true;
for (int i = 1; i < count; ++i) {
if (docIds[start + i - 1] > docIds[start + i]) {
sorted = false;
int last = docIds[start + i - 1];
int current = docIds[start + i];
if (last > current) {
sorted = strictlySorted = false;
break;
} else if (last == current) {
strictlySorted = false;
}
}
if (strictlySorted && (docIds[start + count - 1] - docIds[start] + 1) <= (count << 4)) {
// Only trigger this optimization when max - min + 1 <= 16 * count in order to avoid expanding
// too much storage.
// A field with lower cardinality will have higher probability to trigger this optimization.
out.writeByte((byte) -1);
writeIdsAsBitSet(docIds, start, count, out);
return;
}
if (sorted) {
out.writeByte((byte) 0);
int previous = 0;
@ -85,10 +102,46 @@ class DocIdsWriter {
}
}
/**
 * Writes {@code count} doc IDs from {@code docIds}, beginning at {@code start}, as a word-aligned
 * bitset. The output is a vInt holding the number of skipped leading 64-bit words, a vInt holding
 * the number of words that follow, and then the words themselves. Words are emitted as soon as
 * they are complete, so no intermediate bitset is allocated.
 *
 * <p>Doc IDs are expected in ascending order; the assertion inside the loop checks that word
 * indices never move backwards.
 */
private static void writeIdsAsBitSet(int[] docIds, int start, int count, DataOutput out)
    throws IOException {
  final int firstDoc = docIds[start];
  final int lastDoc = docIds[start + count - 1];
  // Round the first doc ID down to a word boundary so every bit index is relative to it.
  final int offsetWords = firstDoc >> 6;
  final int offsetBits = offsetWords << 6;
  final int totalWordCount = FixedBitSet.bits2words(lastDoc - offsetBits + 1);

  out.writeVInt(offsetWords);
  out.writeVInt(totalWordCount);

  long pendingWord = 0L;
  int pendingWordIndex = 0;
  final int end = start + count;
  for (int i = start; i < end; i++) {
    final int bit = docIds[i] - offsetBits;
    final int targetWordIndex = bit >> 6;
    assert pendingWordIndex <= targetWordIndex;
    if (pendingWordIndex < targetWordIndex) {
      // The pending word is complete: flush it, then pad any gap with all-zero words.
      out.writeLong(pendingWord);
      pendingWord = 0L;
      for (pendingWordIndex++; pendingWordIndex < targetWordIndex; pendingWordIndex++) {
        out.writeLong(0L);
      }
    }
    // A long shift only uses the low 6 bits of the distance, i.e. the bit position in the word.
    pendingWord |= 1L << bit;
  }
  out.writeLong(pendingWord);
  assert pendingWordIndex + 1 == totalWordCount;
}
/** Read {@code count} integers into {@code docIDs}. */
static void readInts(IndexInput in, int count, int[] docIDs) throws IOException {
final int bpv = in.readByte();
switch (bpv) {
case -1:
readBitSet(in, count, docIDs);
break;
case 0:
readDeltaVInts(in, count, docIDs);
break;
@ -103,6 +156,24 @@ class DocIdsWriter {
}
}
/**
 * Reads a bitset-encoded block of doc IDs: a vInt word offset, a vInt word count, then that many
 * longs. Returns an iterator over it whose doc base is the word offset multiplied by 64 and whose
 * cost is {@code count}.
 */
private static DocIdSetIterator readBitSetIterator(IndexInput in, int count) throws IOException {
  final int offsetWords = in.readVInt();
  final int wordCount = in.readVInt();
  final long[] words = new long[wordCount];
  in.readLongs(words, 0, wordCount);
  final FixedBitSet docBits = new FixedBitSet(words, wordCount << 6);
  final int docBase = offsetWords << 6;
  return new DocBaseBitSetIterator(docBits, count, docBase);
}
/**
 * Reads a bitset-encoded block and expands it into {@code docIDs}, filling the array from index 0
 * in the order the iterator returns the documents.
 *
 * @param in the input positioned at the start of the bitset payload
 * @param count the number of doc IDs the block is expected to contain
 * @param docIDs the destination array
 */
private static void readBitSet(IndexInput in, int count, int[] docIDs) throws IOException {
  DocIdSetIterator iterator = readBitSetIterator(in, count);
  int docId, pos = 0;
  while ((docId = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
    docIDs[pos++] = docId;
  }
  // Separate the two values so a failure reads "pos: 5, count: 6" rather than "pos: 5count: 6".
  assert pos == count : "pos: " + pos + ", count: " + count;
}
private static void readDeltaVInts(IndexInput in, int count, int[] docIDs) throws IOException {
int doc = 0;
for (int i = 0; i < count; i++) {
@ -144,6 +215,9 @@ class DocIdsWriter {
static void readInts(IndexInput in, int count, IntersectVisitor visitor) throws IOException {
final int bpv = in.readByte();
switch (bpv) {
case -1:
readBitSet(in, count, visitor);
break;
case 0:
readDeltaVInts(in, count, visitor);
break;
@ -194,4 +268,10 @@ class DocIdsWriter {
visitor.visit((Short.toUnsignedInt(in.readShort()) << 8) | Byte.toUnsignedInt(in.readByte()));
}
}
/**
 * Reads a bitset-encoded block and passes it to {@code visitor} as a single iterator through
 * {@link IntersectVisitor#visit(DocIdSetIterator)}.
 */
private static void readBitSet(IndexInput in, int count, IntersectVisitor visitor)
    throws IOException {
  visitor.visit(readBitSetIterator(in, count));
}
}

View File

@ -18,6 +18,8 @@ package org.apache.lucene.util.bkd;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import org.apache.lucene.index.PointValues.IntersectVisitor;
import org.apache.lucene.index.PointValues.Relation;
import org.apache.lucene.store.Directory;
@ -58,6 +60,22 @@ public class TestDocIdsWriter extends LuceneTestCase {
}
}
/**
 * Round-trips random strictly increasing doc ID arrays that are dense enough (value range at most
 * 16 times the number of IDs) to qualify for the bitset encoding.
 */
public void testBitSet() throws Exception {
  int numIters = atLeast(100);
  try (Directory dir = newDirectory()) {
    for (int iter = 0; iter < numIters; ++iter) {
      // Always generate at least one doc ID. nextInt(5000) may return 0, and the writer's
      // strictly-sorted check reads the first and last doc ID of the block, which do not exist
      // for an empty block.
      int size = 1 + random().nextInt(5000);
      Set<Integer> set = new HashSet<>(size);
      int small = random().nextInt(1000);
      while (set.size() < size) {
        set.add(small + random().nextInt(size * 16));
      }
      int[] docIDs = set.stream().mapToInt(t -> t).sorted().toArray();
      test(dir, docIDs);
    }
  }
}
private void test(Directory dir, int[] ints) throws Exception {
final long len;
try (IndexOutput out = dir.createOutput("tmp", IOContext.DEFAULT)) {

View File

@ -171,6 +171,11 @@ public abstract class MultiRangeQuery extends Query {
adder.add(docID);
}
@Override
public void visit(DocIdSetIterator iterator) throws IOException {
// Bulk variant of visit(int docID): forwards the whole iterator to the adder in a single call.
adder.add(iterator);
}
@Override
public void visit(int docID, byte[] packedValue) {
// If a single OR clause has the value in range, the entire query accepts the value

View File

@ -17,8 +17,10 @@
package org.apache.lucene.spatial3d;
import java.io.IOException;
import org.apache.lucene.index.PointValues.IntersectVisitor;
import org.apache.lucene.index.PointValues.Relation;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.spatial3d.geom.GeoArea;
import org.apache.lucene.spatial3d.geom.GeoAreaFactory;
import org.apache.lucene.spatial3d.geom.GeoShape;
@ -60,6 +62,11 @@ class PointInShapeIntersectVisitor implements IntersectVisitor {
adder.add(docID);
}
@Override
public void visit(DocIdSetIterator iterator) throws IOException {
// Bulk variant of visit(int docID): forwards the whole iterator to the adder in a single call.
adder.add(iterator);
}
@Override
public void visit(int docID, byte[] packedValue) {
assert packedValue.length == 12;