Let `DocIdSetIterator` optimize loading into a FixedBitSet. (#14069)

This is an iteration on #14064. The benefits of this approach are that the API
is a bit nicer and that it allows optimizing in more cases than just when doc
IDs are stored in an int[]. The downside is that it only helps non-scoring
disjunctions for now, but we can look into scoring disjunctions later on.
This commit is contained in:
Adrien Grand 2024-12-17 22:22:49 +01:00 committed by GitHub
parent 5f0fa2b291
commit e74f19bf77
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 175 additions and 24 deletions

View File

@ -35,7 +35,8 @@ Other
API Changes
---------------------
(No changes)
* GITHUB#14069: Added DocIdSetIterator#intoBitSet API to let implementations
optimize loading doc IDs into a bit set. (Adrien Grand)
New Features
---------------------

View File

@ -53,7 +53,9 @@ import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.ReadAdvice;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.VectorUtil;
@ -875,6 +877,63 @@ public final class Lucene101PostingsReader extends PostingsReaderBase {
return doc;
}
@Override
public void intoBitSet(Bits acceptDocs, int upTo, FixedBitSet bitSet, int offset)
throws IOException {
if (doc >= upTo) {
return;
}
// Handle the current doc separately, it may be on the previous docBuffer.
if (acceptDocs == null || acceptDocs.get(doc)) {
bitSet.set(doc - offset);
}
for (; ; ) {
if (docBufferUpto == BLOCK_SIZE) {
// refill
moveToNextLevel0Block();
}
int start = docBufferUpto;
int end = computeBufferEndBoundary(upTo);
if (end != 0) {
bufferIntoBitSet(start, end, acceptDocs, bitSet, offset);
doc = docBuffer[end - 1];
}
docBufferUpto = end;
if (end != BLOCK_SIZE) {
// Either the block is a tail block, or the block did not fully match, we're done.
nextDoc();
assert doc >= upTo;
break;
}
}
}
private int computeBufferEndBoundary(int upTo) {
if (docBufferSize != 0 && docBuffer[docBufferSize - 1] < upTo) {
// All docs in the buffer are under upTo
return docBufferSize;
} else {
// Find the index of the first doc that is greater than or equal to upTo
return VectorUtil.findNextGEQ(docBuffer, upTo, docBufferUpto, docBufferSize);
}
}
private void bufferIntoBitSet(
int start, int end, Bits acceptDocs, FixedBitSet bitSet, int offset) throws IOException {
// acceptDocs#get (if backed by FixedBitSet), bitSet#set and `doc - offset` get
// auto-vectorized
for (int i = start; i < end; ++i) {
int doc = docBuffer[i];
if (acceptDocs == null || acceptDocs.get(doc)) {
bitSet.set(doc - offset);
}
}
}
private void skipPositions(int freq) throws IOException {
// Skip positions now:
int toSkip = posPendingCount - freq;

View File

@ -17,11 +17,11 @@
package org.apache.lucene.search;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Objects;
import org.apache.lucene.internal.hppc.LongArrayList;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.PriorityQueue;
/**
@ -34,8 +34,6 @@ final class BooleanScorer extends BulkScorer {
static final int SHIFT = 12;
static final int SIZE = 1 << SHIFT;
static final int MASK = SIZE - 1;
static final int SET_SIZE = 1 << (SHIFT - 6);
static final int SET_MASK = SET_SIZE - 1;
static class Bucket {
double score;
@ -74,8 +72,7 @@ final class BooleanScorer extends BulkScorer {
// One bucket per doc ID in the window, non-null if scores are needed or if frequencies need to be
// counted
final Bucket[] buckets;
// This is basically an inlined FixedBitSet... seems to help with bound checks
final long[] matching = new long[SET_SIZE];
final FixedBitSet matching = new FixedBitSet(SIZE);
final DisiWrapper[] leads;
final HeadPriorityQueue head;
@ -91,11 +88,12 @@ final class BooleanScorer extends BulkScorer {
@Override
public void forEach(CheckedIntConsumer<IOException> consumer) throws IOException {
long[] matching = BooleanScorer.this.matching;
FixedBitSet matching = BooleanScorer.this.matching;
Bucket[] buckets = BooleanScorer.this.buckets;
int base = this.base;
for (int idx = 0; idx < matching.length; idx++) {
long bits = matching[idx];
long[] bitArray = matching.getBits();
for (int idx = 0; idx < bitArray.length; idx++) {
long bits = bitArray[idx];
while (bits != 0L) {
int ntz = Long.numberOfTrailingZeros(bits);
if (buckets != null) {
@ -121,11 +119,7 @@ final class BooleanScorer extends BulkScorer {
// We can't just count bits in that case
return super.count();
}
int count = 0;
for (long l : matching) {
count += Long.bitCount(l);
}
return count;
return matching.cardinality();
}
}
@ -173,7 +167,7 @@ final class BooleanScorer extends BulkScorer {
private void scoreDisiWrapperIntoBitSet(DisiWrapper w, Bits acceptDocs, int min, int max)
throws IOException {
boolean needsScores = BooleanScorer.this.needsScores;
long[] matching = BooleanScorer.this.matching;
FixedBitSet matching = BooleanScorer.this.matching;
Bucket[] buckets = BooleanScorer.this.buckets;
DocIdSetIterator it = w.iterator;
@ -182,12 +176,13 @@ final class BooleanScorer extends BulkScorer {
if (doc < min) {
doc = it.advance(min);
}
for (; doc < max; doc = it.nextDoc()) {
if (acceptDocs == null || acceptDocs.get(doc)) {
final int i = doc & MASK;
final int idx = i >> 6;
matching[idx] |= 1L << i;
if (buckets != null) {
if (buckets == null) {
it.intoBitSet(acceptDocs, max, matching, doc & ~MASK);
} else {
for (; doc < max; doc = it.nextDoc()) {
if (acceptDocs == null || acceptDocs.get(doc)) {
final int i = doc & MASK;
matching.set(i);
final Bucket bucket = buckets[i];
bucket.freq++;
if (needsScores) {
@ -197,7 +192,7 @@ final class BooleanScorer extends BulkScorer {
}
}
w.doc = doc;
w.doc = it.docID();
}
private void scoreWindowIntoBitSetAndReplay(
@ -218,7 +213,7 @@ final class BooleanScorer extends BulkScorer {
docIdStreamView.base = base;
collector.collect(docIdStreamView);
Arrays.fill(matching, 0L);
matching.clear();
}
private DisiWrapper advance(int min) throws IOException {

View File

@ -17,6 +17,8 @@
package org.apache.lucene.search;
import java.io.IOException;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
/**
* This abstract class defines methods to iterate over a set of non-decreasing doc ids. Note that
@ -211,4 +213,33 @@ public abstract class DocIdSetIterator {
* may be a rough heuristic, hardcoded value, or otherwise completely inaccurate.
*/
public abstract long cost();
/**
* Load doc IDs into a {@link FixedBitSet}. This should behave exactly as if implemented as below,
* which is the default implementation:
*
* <pre class="prettyprint">
* for (int doc = docID(); doc &lt; upTo; doc = nextDoc()) {
* if (acceptDocs == null || acceptDocs.get(doc)) {
* bitSet.set(doc - offset);
* }
* }
* </pre>
*
* <p><b>Note</b>: {@code offset} must be less than or equal to the {@link #docID() current doc
* ID}.
*
* <p><b>Note</b>: It is important not to clear bits from {@code bitSet} that may be already set.
*
* @lucene.internal
*/
public void intoBitSet(Bits acceptDocs, int upTo, FixedBitSet bitSet, int offset)
throws IOException {
assert offset <= docID();
for (int doc = docID(); doc < upTo; doc = nextDoc()) {
if (acceptDocs == null || acceptDocs.get(doc)) {
bitSet.set(doc - offset);
}
}
}
}

View File

@ -343,7 +343,9 @@ public final class FixedBitSet extends BitSet {
DocBaseBitSetIterator baseIter = (DocBaseBitSetIterator) iter;
or(baseIter.getDocBase() >> 6, baseIter.getBitSet());
} else {
super.or(iter);
checkUnpositioned(iter);
iter.nextDoc();
iter.intoBitSet(null, DocIdSetIterator.NO_MORE_DOCS, this, 0);
}
}

View File

@ -75,6 +75,7 @@ import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.tests.util.automaton.AutomatonTestUtil;
import org.apache.lucene.tests.util.automaton.AutomatonTestUtil.RandomAcceptedStrings;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.StringHelper;
import org.apache.lucene.util.UnicodeUtil;
@ -110,6 +111,9 @@ public class RandomPostingsTester {
// Sometimes don't fully consume positions at each doc
PARTIAL_POS_CONSUME,
// Check DocIdSetIterator#intoBitSet
INTO_BIT_SET,
// Sometimes check payloads
PAYLOADS,
@ -1364,6 +1368,54 @@ public class RandomPostingsTester {
idx <= impactsCopy.size() && impactsCopy.get(idx).norm <= norm);
}
}
if (options.contains(Option.INTO_BIT_SET)) {
int flags = PostingsEnum.FREQS;
if (doCheckPositions) {
flags |= PostingsEnum.POSITIONS;
if (doCheckOffsets) {
flags |= PostingsEnum.OFFSETS;
}
if (doCheckPayloads) {
flags |= PostingsEnum.PAYLOADS;
}
}
PostingsEnum pe1 = termsEnum.postings(null, flags);
if (random.nextBoolean()) {
pe1.advance(maxDoc / 2);
pe1 = termsEnum.postings(pe1, flags);
}
PostingsEnum pe2 = termsEnum.postings(null, flags);
FixedBitSet set1 = new FixedBitSet(1024);
FixedBitSet set2 = new FixedBitSet(1024);
FixedBitSet acceptDocs = new FixedBitSet(maxDoc);
for (int i = 0; i < maxDoc; i += 2) {
acceptDocs.set(i);
}
while (true) {
pe1.nextDoc();
pe2.nextDoc();
int offset =
TestUtil.nextInt(random, Math.max(0, pe1.docID() - set1.length()), pe1.docID());
int upTo = offset + random.nextInt(set1.length());
pe1.intoBitSet(acceptDocs, upTo, set1, offset);
for (int d = pe2.docID(); d < upTo; d = pe2.nextDoc()) {
if (acceptDocs.get(d)) {
set2.set(d - offset);
}
}
assertEquals(set1, set2);
assertEquals(pe1.docID(), pe2.docID());
if (pe1.docID() == DocIdSetIterator.NO_MORE_DOCS) {
break;
}
set1.clear();
set2.clear();
}
}
}
private static class TestThread extends Thread {

View File

@ -24,6 +24,8 @@ import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TwoPhaseIterator;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
/** Wraps a Scorer with additional checks */
public class AssertingScorer extends Scorer {
@ -192,6 +194,15 @@ public class AssertingScorer extends Scorer {
public long cost() {
return in.cost();
}
@Override
public void intoBitSet(Bits acceptDocs, int upTo, FixedBitSet bitSet, int offset)
throws IOException {
assert docID() != -1;
assert offset <= docID();
in.intoBitSet(acceptDocs, upTo, bitSet, offset);
assert docID() >= upTo;
}
};
}