Use the new `loadIntoBitSet` API to speed up dense conjunctions. (#14080)

Now that loading doc IDs into a bit set is much more efficient thanks to
auto-vectorization, it has become tempting to evaluate dense conjunctions by
and-ing bit sets.
This commit is contained in:
Adrien Grand 2024-12-19 15:05:14 +01:00 committed by GitHub
parent aef16daa76
commit a337d14b21
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 484 additions and 10 deletions

View File

@ -55,6 +55,9 @@ Optimizations
* GITHUB#14052: Speed up DisjunctionDISIApproximation#advance. (Adrien Grand) * GITHUB#14052: Speed up DisjunctionDISIApproximation#advance. (Adrien Grand)
* GITHUB#14080: Use the `DocIdSetIterator#loadIntoBitSet` API to speed up dense
conjunctions. (Adrien Grand)
Bug Fixes Bug Fixes
--------------------- ---------------------
(No changes) (No changes)

View File

@ -304,9 +304,9 @@ final class BooleanScorerSupplier extends ScorerSupplier {
BulkScorer filteredOptionalBulkScorer() throws IOException { BulkScorer filteredOptionalBulkScorer() throws IOException {
if (subs.get(Occur.MUST).isEmpty() == false if (subs.get(Occur.MUST).isEmpty() == false
|| subs.get(Occur.FILTER).isEmpty() || subs.get(Occur.FILTER).isEmpty()
|| scoreMode != ScoreMode.TOP_SCORES || (scoreMode.needsScores() && scoreMode != ScoreMode.TOP_SCORES)
|| subs.get(Occur.SHOULD).size() <= 1 || subs.get(Occur.SHOULD).size() <= 1
|| minShouldMatch > 1) { || minShouldMatch != 1) {
return null; return null;
} }
long cost = cost(); long cost = cost();
@ -318,6 +318,7 @@ final class BooleanScorerSupplier extends ScorerSupplier {
for (ScorerSupplier ss : subs.get(Occur.FILTER)) { for (ScorerSupplier ss : subs.get(Occur.FILTER)) {
filters.add(ss.get(cost)); filters.add(ss.get(cost));
} }
if (scoreMode == ScoreMode.TOP_SCORES) {
Scorer filterScorer; Scorer filterScorer;
if (filters.size() == 1) { if (filters.size() == 1) {
filterScorer = filters.iterator().next(); filterScorer = filters.iterator().next();
@ -325,6 +326,20 @@ final class BooleanScorerSupplier extends ScorerSupplier {
filterScorer = new ConjunctionScorer(filters, Collections.emptySet()); filterScorer = new ConjunctionScorer(filters, Collections.emptySet());
} }
return new MaxScoreBulkScorer(maxDoc, optionalScorers, filterScorer); return new MaxScoreBulkScorer(maxDoc, optionalScorers, filterScorer);
} else {
// In the beginning of this method, we exited early if the score mode is not either TOP_SCORES
// or a score mode that doesn't need scores.
assert scoreMode.needsScores() == false;
filters.add(new DisjunctionSumScorer(optionalScorers, scoreMode, cost));
if (filters.stream().map(Scorer::twoPhaseIterator).allMatch(Objects::isNull)
&& maxDoc >= DenseConjunctionBulkScorer.WINDOW_SIZE
&& cost >= maxDoc / DenseConjunctionBulkScorer.DENSITY_THRESHOLD_INVERSE) {
return new DenseConjunctionBulkScorer(filters.stream().map(Scorer::iterator).toList());
}
return new DefaultBulkScorer(new ConjunctionScorer(filters, Collections.emptyList()));
}
} }
// Return a BulkScorer for the required clauses only // Return a BulkScorer for the required clauses only
@ -378,8 +393,15 @@ final class BooleanScorerSupplier extends ScorerSupplier {
&& requiredScoring.size() + requiredNoScoring.size() >= 2 && requiredScoring.size() + requiredNoScoring.size() >= 2
&& requiredScoring.stream().map(Scorer::twoPhaseIterator).allMatch(Objects::isNull) && requiredScoring.stream().map(Scorer::twoPhaseIterator).allMatch(Objects::isNull)
&& requiredNoScoring.stream().map(Scorer::twoPhaseIterator).allMatch(Objects::isNull)) { && requiredNoScoring.stream().map(Scorer::twoPhaseIterator).allMatch(Objects::isNull)) {
if (requiredScoring.isEmpty()
&& maxDoc >= DenseConjunctionBulkScorer.WINDOW_SIZE
&& leadCost >= maxDoc / DenseConjunctionBulkScorer.DENSITY_THRESHOLD_INVERSE) {
return new DenseConjunctionBulkScorer(
requiredNoScoring.stream().map(Scorer::iterator).toList());
} else {
return new ConjunctionBulkScorer(requiredScoring, requiredNoScoring); return new ConjunctionBulkScorer(requiredScoring, requiredNoScoring);
} }
}
if (scoreMode == ScoreMode.TOP_SCORES && requiredScoring.size() > 1) { if (scoreMode == ScoreMode.TOP_SCORES && requiredScoring.size() > 1) {
requiredScoring = Collections.singletonList(new BlockMaxConjunctionScorer(requiredScoring)); requiredScoring = Collections.singletonList(new BlockMaxConjunctionScorer(requiredScoring));
} }

View File

@ -0,0 +1,193 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
/**
* BulkScorer implementation of {@link ConjunctionScorer} that is specialized for dense clauses.
* Whenever sensible, it intersects clauses by loading their matches into a bit set and computing
* the intersection of clauses by and-ing these bit sets.
*/
final class DenseConjunctionBulkScorer extends BulkScorer {
// Use a small-ish window size to make sure that we can take advantage of gaps in the postings of
// clauses that are not leading iteration.
static final int WINDOW_SIZE = 4096;
// Only use bit sets to compute the intersection if more than 1/32th of the docs are expected to
// match. Experiments suggested that values that are a bit higher than this would work better, but
// we're erring on the conservative side.
static final int DENSITY_THRESHOLD_INVERSE = Long.SIZE / 2;
private final DocIdSetIterator lead;
private final List<DocIdSetIterator> others;
private final FixedBitSet windowMatches = new FixedBitSet(WINDOW_SIZE);
private final FixedBitSet clauseWindowMatches = new FixedBitSet(WINDOW_SIZE);
private final DocIdStreamView docIdStreamView = new DocIdStreamView();
DenseConjunctionBulkScorer(List<DocIdSetIterator> iterators) {
if (iterators.size() <= 1) {
throw new IllegalArgumentException("Expected 2 or more clauses, got " + iterators.size());
}
iterators = new ArrayList<>(iterators);
iterators.sort(Comparator.comparingLong(DocIdSetIterator::cost));
lead = iterators.get(0);
others = List.copyOf(iterators.subList(1, iterators.size()));
}
@Override
public int score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException {
for (DocIdSetIterator it : others) {
min = Math.max(min, it.docID());
}
if (lead.docID() < min) {
lead.advance(min);
}
if (lead.docID() >= max) {
return lead.docID();
}
List<DocIdSetIterator> otherIterators = this.others;
DocIdSetIterator collectorIterator = collector.competitiveIterator();
if (collectorIterator != null) {
otherIterators = new ArrayList<>(otherIterators);
otherIterators.add(collectorIterator);
}
final DocIdSetIterator[] others = otherIterators.toArray(DocIdSetIterator[]::new);
int windowMax;
do {
windowMax = (int) Math.min(max, (long) lead.docID() + WINDOW_SIZE);
scoreWindowUsingBitSet(collector, acceptDocs, others, windowMax);
} while (windowMax < max);
return lead.docID();
}
private static int advance(FixedBitSet set, int i) {
if (i >= WINDOW_SIZE) {
return DocIdSetIterator.NO_MORE_DOCS;
} else {
return set.nextSetBit(i);
}
}
private void scoreWindowUsingBitSet(
LeafCollector collector, Bits acceptDocs, DocIdSetIterator[] others, int max)
throws IOException {
assert windowMatches.scanIsEmpty();
assert clauseWindowMatches.scanIsEmpty();
int offset = lead.docID();
lead.intoBitSet(acceptDocs, max, windowMatches, offset);
int upTo = 0;
for (;
upTo < others.length
&& windowMatches.cardinality() >= WINDOW_SIZE / DENSITY_THRESHOLD_INVERSE;
upTo++) {
DocIdSetIterator other = others[upTo];
if (other.docID() < offset) {
other.advance(offset);
}
// No need to apply acceptDocs on other clauses since we already applied live docs on the
// leading clause.
other.intoBitSet(null, max, clauseWindowMatches, offset);
windowMatches.and(clauseWindowMatches);
clauseWindowMatches.clear();
}
if (upTo < others.length) {
// If the leading clause is sparse on this doc ID range or if the intersection became sparse
// after applying a few clauses, we finish evaluating the intersection using the traditional
// leap-frog approach. This proved important with a query such as "+secretary +of +state" on
// wikibigall, where the intersection becomes sparse after intersecting "secretary" and
// "state".
advanceHead:
for (int windowMatch = windowMatches.nextSetBit(0);
windowMatch != DocIdSetIterator.NO_MORE_DOCS; ) {
int doc = offset + windowMatch;
for (int i = upTo; i < others.length; ++i) {
DocIdSetIterator other = others[i];
int otherDoc = other.docID();
if (otherDoc < doc) {
otherDoc = other.advance(doc);
}
if (doc != otherDoc) {
int clearUpTo = Math.min(WINDOW_SIZE, otherDoc - offset);
windowMatches.clear(windowMatch, clearUpTo);
windowMatch = advance(windowMatches, clearUpTo);
continue advanceHead;
}
}
windowMatch = advance(windowMatches, windowMatch + 1);
}
}
docIdStreamView.offset = offset;
collector.collect(docIdStreamView);
windowMatches.clear();
// If another clause is more advanced than lead1 then advance lead1, it's important to take
// advantage of large gaps in the postings lists of other clauses.
int maxOtherDocID = -1;
for (DocIdSetIterator other : others) {
maxOtherDocID = Math.max(maxOtherDocID, other.docID());
}
if (lead.docID() < maxOtherDocID) {
lead.advance(maxOtherDocID);
}
}
@Override
public long cost() {
return lead.cost();
}
final class DocIdStreamView extends DocIdStream {
int offset;
@Override
public void forEach(CheckedIntConsumer<IOException> consumer) throws IOException {
int offset = this.offset;
long[] bitArray = windowMatches.getBits();
for (int idx = 0; idx < bitArray.length; idx++) {
long bits = bitArray[idx];
while (bits != 0L) {
int ntz = Long.numberOfTrailingZeros(bits);
consumer.accept(offset + ((idx << 6) | ntz));
bits ^= 1L << ntz;
}
}
}
@Override
public int count() throws IOException {
return windowMatches.cardinality();
}
}
}

View File

@ -21,6 +21,8 @@ import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Comparator; import java.util.Comparator;
import java.util.List; import java.util.List;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
/** /**
* A {@link DocIdSetIterator} which is a disjunction of the approximations of the provided * A {@link DocIdSetIterator} which is a disjunction of the approximations of the provided
@ -141,6 +143,23 @@ public final class DisjunctionDISIApproximation extends DocIdSetIterator {
return Math.min(leadTop.doc, minOtherDoc); return Math.min(leadTop.doc, minOtherDoc);
} }
@Override
public void intoBitSet(Bits acceptDocs, int upTo, FixedBitSet bitSet, int offset)
throws IOException {
while (leadTop.doc < upTo) {
leadTop.approximation.intoBitSet(acceptDocs, upTo, bitSet, offset);
leadTop.doc = leadTop.approximation.docID();
leadTop = leadIterators.updateTop();
}
minOtherDoc = Integer.MAX_VALUE;
for (DisiWrapper w : otherIterators) {
w.approximation.intoBitSet(acceptDocs, upTo, bitSet, offset);
w.doc = w.approximation.docID();
minOtherDoc = Math.min(minOtherDoc, w.doc);
}
}
/** Return the linked list of iterators positioned on the current doc. */ /** Return the linked list of iterators positioned on the current doc. */
public DisiWrapper topList() { public DisiWrapper topList() {
if (leadTop.doc < minOtherDoc) { if (leadTop.doc < minOtherDoc) {

View File

@ -33,6 +33,10 @@ public final class FixedBitSet extends BitSet {
private static final long BASE_RAM_BYTES_USED = private static final long BASE_RAM_BYTES_USED =
RamUsageEstimator.shallowSizeOfInstance(FixedBitSet.class); RamUsageEstimator.shallowSizeOfInstance(FixedBitSet.class);
// An array that is small enough to use reasonable amounts of RAM and large enough to allow
// Arrays#mismatch to use SIMD instructions and multiple registers under the hood.
private static long[] ZEROES = new long[32];
private final long[] bits; // Array of longs holding the bits private final long[] bits; // Array of longs holding the bits
private final int numBits; // The number of bits in use private final int numBits; // The number of bits in use
private final int numWords; // The exact number of longs needed to hold numBits (<= bits.length) private final int numWords; // The exact number of longs needed to hold numBits (<= bits.length)
@ -470,8 +474,11 @@ public final class FixedBitSet extends BitSet {
// Depends on the ghost bits being clear! // Depends on the ghost bits being clear!
final int count = numWords; final int count = numWords;
for (int i = 0; i < count; i++) { for (int i = 0; i < count; i += ZEROES.length) {
if (bits[i] != 0) return false; int cmpLen = Math.min(ZEROES.length, bits.length - i);
if (Arrays.equals(bits, i, i + cmpLen, ZEROES, 0, cmpLen) == false) {
return false;
}
} }
return true; return true;

View File

@ -0,0 +1,204 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
import java.util.Arrays;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.FixedBitSet;
public class TestDenseConjunctionBulkScorer extends LuceneTestCase {
public void testSameMatches() throws IOException {
int maxDoc = 100_000;
FixedBitSet clause1 = new FixedBitSet(maxDoc);
FixedBitSet clause2 = new FixedBitSet(maxDoc);
FixedBitSet clause3 = new FixedBitSet(maxDoc);
for (int i = 0; i < maxDoc; i += 2) {
clause1.set(i);
clause2.set(i);
clause3.set(i);
}
DenseConjunctionBulkScorer scorer =
new DenseConjunctionBulkScorer(
Arrays.asList(
new BitSetIterator(clause1, clause1.approximateCardinality()),
new BitSetIterator(clause2, clause2.approximateCardinality()),
new BitSetIterator(clause3, clause3.approximateCardinality())));
FixedBitSet result = new FixedBitSet(maxDoc);
scorer.score(
new LeafCollector() {
@Override
public void setScorer(Scorable scorer) throws IOException {}
@Override
public void collect(int doc) throws IOException {
result.set(doc);
}
},
null,
0,
DocIdSetIterator.NO_MORE_DOCS);
assertEquals(clause1, result);
}
public void testApplyAcceptDocs() throws IOException {
int maxDoc = 100_000;
FixedBitSet clause1 = new FixedBitSet(maxDoc);
FixedBitSet clause2 = new FixedBitSet(maxDoc);
clause1.set(0, maxDoc);
clause2.set(0, maxDoc);
FixedBitSet acceptDocs = new FixedBitSet(maxDoc);
for (int i = 0; i < maxDoc; i += 2) {
acceptDocs.set(i);
}
DenseConjunctionBulkScorer scorer =
new DenseConjunctionBulkScorer(
Arrays.asList(
new BitSetIterator(clause1, clause1.approximateCardinality()),
new BitSetIterator(clause2, clause2.approximateCardinality())));
FixedBitSet result = new FixedBitSet(maxDoc);
scorer.score(
new LeafCollector() {
@Override
public void setScorer(Scorable scorer) throws IOException {}
@Override
public void collect(int doc) throws IOException {
result.set(doc);
}
},
acceptDocs,
0,
DocIdSetIterator.NO_MORE_DOCS);
assertEquals(acceptDocs, result);
}
public void testEmptyIntersection() throws IOException {
int maxDoc = 100_000;
FixedBitSet clause1 = new FixedBitSet(maxDoc);
FixedBitSet clause2 = new FixedBitSet(maxDoc);
for (int i = 0; i < maxDoc - 1; i += 2) {
clause1.set(i);
clause2.set(i + 1);
}
DenseConjunctionBulkScorer scorer =
new DenseConjunctionBulkScorer(
Arrays.asList(
new BitSetIterator(clause1, clause1.approximateCardinality()),
new BitSetIterator(clause2, clause2.approximateCardinality())));
FixedBitSet result = new FixedBitSet(maxDoc);
scorer.score(
new LeafCollector() {
@Override
public void setScorer(Scorable scorer) throws IOException {}
@Override
public void collect(int doc) throws IOException {
result.set(doc);
}
},
null,
0,
DocIdSetIterator.NO_MORE_DOCS);
assertTrue(result.scanIsEmpty());
}
public void testClustered() throws IOException {
int maxDoc = 100_000;
FixedBitSet clause1 = new FixedBitSet(maxDoc);
FixedBitSet clause2 = new FixedBitSet(maxDoc);
FixedBitSet clause3 = new FixedBitSet(maxDoc);
clause1.set(10_000, 90_000);
clause2.set(0, 80_000);
clause3.set(20_000, 100_000);
DenseConjunctionBulkScorer scorer =
new DenseConjunctionBulkScorer(
Arrays.asList(
new BitSetIterator(clause1, clause1.approximateCardinality()),
new BitSetIterator(clause2, clause2.approximateCardinality()),
new BitSetIterator(clause3, clause3.approximateCardinality())));
FixedBitSet result = new FixedBitSet(maxDoc);
scorer.score(
new LeafCollector() {
@Override
public void setScorer(Scorable scorer) throws IOException {}
@Override
public void collect(int doc) throws IOException {
result.set(doc);
}
},
null,
0,
DocIdSetIterator.NO_MORE_DOCS);
FixedBitSet expected = new FixedBitSet(maxDoc);
expected.set(20_000, 80_000);
assertArrayEquals(expected.getBits(), result.getBits());
assertEquals(expected, result);
}
public void testSparseAfter2ndClause() throws IOException {
int maxDoc = 100_000;
FixedBitSet clause1 = new FixedBitSet(maxDoc);
FixedBitSet clause2 = new FixedBitSet(maxDoc);
FixedBitSet clause3 = new FixedBitSet(maxDoc);
// 13 and 17 are primes, so their only intersection is on multiples of both 13 and 17
// Likewise, 19 is prime, so the only intersection of the conjunction is on multiples of 13, 17
// and 19
for (int i = 0; i < maxDoc; i += 13) {
clause1.set(i);
}
for (int i = 0; i < maxDoc; i += 17) {
clause2.set(i);
}
for (int i = 0; i < maxDoc; i += 19) {
clause3.set(i);
}
DenseConjunctionBulkScorer scorer =
new DenseConjunctionBulkScorer(
Arrays.asList(
new BitSetIterator(clause1, clause1.approximateCardinality()),
new BitSetIterator(clause2, clause2.approximateCardinality()),
new BitSetIterator(clause3, clause3.approximateCardinality())));
FixedBitSet result = new FixedBitSet(maxDoc);
scorer.score(
new LeafCollector() {
@Override
public void setScorer(Scorable scorer) throws IOException {}
@Override
public void collect(int doc) throws IOException {
result.set(doc);
}
},
null,
0,
DocIdSetIterator.NO_MORE_DOCS);
FixedBitSet expected = new FixedBitSet(maxDoc);
for (int i = 0; i < maxDoc; i += 13 * 17 * 19) {
expected.set(i);
}
assertEquals(expected, result);
}
}

View File

@ -616,4 +616,30 @@ public class TestFixedBitSet extends BaseBitSetTestCase<FixedBitSet> {
set.set(5); set.set(5);
assertTrue(bits.get(5)); assertTrue(bits.get(5));
} }
public void testScanIsEmpty() {
FixedBitSet set = new FixedBitSet(0);
assertTrue(set.scanIsEmpty());
set = new FixedBitSet(13);
assertTrue(set.scanIsEmpty());
set.set(10);
assertFalse(set.scanIsEmpty());
set = new FixedBitSet(1024);
assertTrue(set.scanIsEmpty());
set.set(3);
assertFalse(set.scanIsEmpty());
set.clear(3);
set.set(1020);
assertFalse(set.scanIsEmpty());
set = new FixedBitSet(1030);
assertTrue(set.scanIsEmpty());
set.set(3);
assertFalse(set.scanIsEmpty());
set.clear(3);
set.set(1028);
assertFalse(set.scanIsEmpty());
}
} }