Move BooleanScorer to work on top of Scorers rather than BulkScorers. (#13931)

I was looking at some queries where Lucene performs significantly worse than
Tantivy at https://tantivy-search.github.io/bench/, and found that we incur
significant overhead from implementing `BooleanScorer` on top of `BulkScorer`
(effectively `DefaultBulkScorer`, since the benchmark only runs term queries
as boolean clauses) rather than on top of `Scorer` directly.

The `CountOrHighHigh` and `CountOrHighMed` tasks are a bit noisy on my machine,
so I did 3 runs on wikibigall, and all of them had speedups for these two
tasks, often with a very low p-value.

In theory, this change could make things slower when the inner query has a
specialized bulk scorer, such as `MatchAllDocsQuery` or a conjunction. It does
feel right to optimize for term queries though.
This commit is contained in:
Adrien Grand 2024-10-21 16:55:04 +02:00 committed by GitHub
parent 86457a5f33
commit a779a64d7b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 118 additions and 197 deletions

View File

@ -55,6 +55,11 @@ Optimizations
* GITHUB#13930: Use growNoCopy when copying bytes in BytesRefBuilder. (Ignacio Vera)
* GITHUB#13931: Refactored `BooleanScorer` to evaluate matches of sub clauses
using the `Scorer` abstraction rather than the `BulkScorer` abstraction. This
speeds up exhaustive evaluation of disjunctions of term queries.
(Adrien Grand)
Bug Fixes
---------------------
* GITHUB#13832: Fixed an issue where the DefaultPassageFormatter.format method did not format passages as intended

View File

@ -20,13 +20,14 @@ import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Objects;
import org.apache.lucene.internal.hppc.LongArrayList;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.PriorityQueue;
/**
* {@link BulkScorer} that is used for pure disjunctions and disjunctions that have low values of
* {@link BooleanQuery.Builder#setMinimumNumberShouldMatch(int)} and dense clauses. This scorer
* scores documents by batches of 2048 docs.
* scores documents by batches of 4,096 docs.
*/
final class BooleanScorer extends BulkScorer {
@ -41,71 +42,32 @@ final class BooleanScorer extends BulkScorer {
int freq;
}
/**
 * Couples a clause's {@link BulkScorer} with the two values the head and tail queues order on:
 * the cost captured at construction time and the value returned by the most recent call to
 * {@link BulkScorer#score}.
 */
private class BulkScorerAndDoc {
  final BulkScorer scorer;
  final long cost;
  // Result of the last scorer.score(...) call; stays at -1 until the first call.
  int next = -1;

  BulkScorerAndDoc(BulkScorer bulkScorer) {
    scorer = bulkScorer;
    cost = bulkScorer.cost();
  }

  /**
   * Refreshes {@link #next} by scoring the empty range {@code [min, min)} into the enclosing
   * scorer's {@code orCollector}, without any accepted-docs filter.
   */
  void advance(int min) throws IOException {
    score(orCollector, null, min, min);
  }

  /** Delegates to the wrapped scorer and remembers what it returned in {@link #next}. */
  void score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException {
    next = scorer.score(collector, acceptDocs, min, max);
  }
}
/**
 * Sums the costs of the clauses that remain in a queue of capacity {@code scorers.size() -
 * minShouldMatch + 1} once every clause has been offered to it. See WANDScorer for an
 * explanation.
 */
private static long cost(Collection<BulkScorer> scorers, int minShouldMatch) {
  final int capacity = scorers.size() - minShouldMatch + 1;
  // lessThan is deliberately inverted (higher cost sorts first), so overflow presumably evicts
  // the most costly clause and the queue retains the cheapest ones -- confirm against
  // PriorityQueue#insertWithOverflow.
  final PriorityQueue<BulkScorer> retained =
      new PriorityQueue<BulkScorer>(capacity) {
        @Override
        protected boolean lessThan(BulkScorer left, BulkScorer right) {
          return left.cost() > right.cost();
        }
      };
  for (BulkScorer clause : scorers) {
    retained.insertWithOverflow(clause);
  }

  long total = 0;
  BulkScorer clause;
  // Drain the queue; the loop stops when pop() yields null.
  while ((clause = retained.pop()) != null) {
    total += clause.cost();
  }
  return total;
}
static final class HeadPriorityQueue extends PriorityQueue<BulkScorerAndDoc> {
static final class HeadPriorityQueue extends PriorityQueue<DisiWrapper> {
public HeadPriorityQueue(int maxSize) {
super(maxSize);
}
@Override
protected boolean lessThan(BulkScorerAndDoc a, BulkScorerAndDoc b) {
return a.next < b.next;
protected boolean lessThan(DisiWrapper a, DisiWrapper b) {
return a.doc < b.doc;
}
}
static final class TailPriorityQueue extends PriorityQueue<BulkScorerAndDoc> {
static final class TailPriorityQueue extends PriorityQueue<DisiWrapper> {
public TailPriorityQueue(int maxSize) {
super(maxSize);
}
@Override
protected boolean lessThan(BulkScorerAndDoc a, BulkScorerAndDoc b) {
protected boolean lessThan(DisiWrapper a, DisiWrapper b) {
return a.cost < b.cost;
}
public BulkScorerAndDoc get(int i) {
public DisiWrapper get(int i) {
Objects.checkIndex(i, size());
return (BulkScorerAndDoc) getHeapArray()[1 + i];
return (DisiWrapper) getHeapArray()[1 + i];
}
}
@ -115,7 +77,7 @@ final class BooleanScorer extends BulkScorer {
// This is basically an inlined FixedBitSet... seems to help with bound checks
final long[] matching = new long[SET_SIZE];
final BulkScorerAndDoc[] leads;
final DisiWrapper[] leads;
final HeadPriorityQueue head;
final TailPriorityQueue tail;
final Score score = new Score();
@ -123,31 +85,6 @@ final class BooleanScorer extends BulkScorer {
final long cost;
final boolean needsScores;
/**
 * Collector that records each collected doc in the enclosing scorer's {@code matching} bit set
 * and, when {@code buckets} is non-null, in the per-slot {@link Bucket}.
 */
final class OrCollector implements LeafCollector {
  Scorable scorer;

  @Override
  public void setScorer(Scorable scorer) {
    this.scorer = scorer;
  }

  @Override
  public void collect(int doc) throws IOException {
    final int slot = doc & MASK;
    // One long holds 64 slots: the word index is slot / 64, and the long shift keeps only the
    // low six bits of 'slot' as the bit position.
    matching[slot >>> 6] |= 1L << slot;
    if (buckets == null) {
      return;
    }
    final Bucket target = buckets[slot];
    target.freq++;
    if (needsScores) {
      target.score += scorer.score();
    }
  }
}

final OrCollector orCollector = new OrCollector();
final class DocIdStreamView extends DocIdStream {
int base;
@ -194,7 +131,7 @@ final class BooleanScorer extends BulkScorer {
private final DocIdStreamView docIdStreamView = new DocIdStreamView();
BooleanScorer(Collection<BulkScorer> scorers, int minShouldMatch, boolean needsScores) {
BooleanScorer(Collection<Scorer> scorers, int minShouldMatch, boolean needsScores) {
if (minShouldMatch < 1 || minShouldMatch > scorers.size()) {
throw new IllegalArgumentException(
"minShouldMatch should be within 1..num_scorers. Got " + minShouldMatch);
@ -211,18 +148,21 @@ final class BooleanScorer extends BulkScorer {
} else {
buckets = null;
}
this.leads = new BulkScorerAndDoc[scorers.size()];
this.leads = new DisiWrapper[scorers.size()];
this.head = new HeadPriorityQueue(scorers.size() - minShouldMatch + 1);
this.tail = new TailPriorityQueue(minShouldMatch - 1);
this.minShouldMatch = minShouldMatch;
this.needsScores = needsScores;
for (BulkScorer scorer : scorers) {
final BulkScorerAndDoc evicted = tail.insertWithOverflow(new BulkScorerAndDoc(scorer));
LongArrayList costs = new LongArrayList(scorers.size());
for (Scorer scorer : scorers) {
DisiWrapper w = new DisiWrapper(scorer);
costs.add(w.cost);
final DisiWrapper evicted = tail.insertWithOverflow(w);
if (evicted != null) {
head.add(evicted);
}
}
this.cost = cost(scorers, minShouldMatch);
this.cost = ScorerUtil.costWithMinShouldMatch(costs.stream(), costs.size(), minShouldMatch);
}
@Override
@ -230,19 +170,49 @@ final class BooleanScorer extends BulkScorer {
return cost;
}
/**
 * Walks the iterator of {@code w} across {@code [min, max)} and records every accepted doc in
 * {@code matching}; when {@code buckets} is non-null it also bumps the slot's frequency and, if
 * scores are needed, adds the clause's score.
 *
 * <p>Iteration starts at {@code w.doc}, or at {@code advance(min)} when {@code w.doc} is before
 * {@code min}. On return {@code w.doc} holds the first doc the iterator produced that is not
 * below {@code max}.
 *
 * @param w the clause to consume; its {@code doc} field is updated
 * @param acceptDocs docs that may be recorded, or {@code null} to accept every doc
 * @param min first doc of the range, inclusive
 * @param max end of the range, exclusive
 */
private void scoreDisiWrapperIntoBitSet(DisiWrapper w, Bits acceptDocs, int min, int max)
    throws IOException {
  // Read the fields once into locals before entering the loop.
  final boolean accumulateScores = BooleanScorer.this.needsScores;
  final long[] bits = BooleanScorer.this.matching;
  final Bucket[] slots = BooleanScorer.this.buckets;
  final DocIdSetIterator iterator = w.iterator;
  final Scorer clause = w.scorer;

  int current = w.doc;
  if (current < min) {
    current = iterator.advance(min);
  }
  while (current < max) {
    if (acceptDocs == null || acceptDocs.get(current)) {
      final int slot = current & MASK;
      bits[slot >> 6] |= 1L << slot;
      if (slots != null) {
        final Bucket target = slots[slot];
        target.freq++;
        if (accumulateScores) {
          target.score += clause.score();
        }
      }
    }
    current = iterator.nextDoc();
  }
  w.doc = current;
}
private void scoreWindowIntoBitSetAndReplay(
LeafCollector collector,
Bits acceptDocs,
int base,
int min,
int max,
BulkScorerAndDoc[] scorers,
DisiWrapper[] scorers,
int numScorers)
throws IOException {
for (int i = 0; i < numScorers; ++i) {
final BulkScorerAndDoc scorer = scorers[i];
assert scorer.next < max;
scorer.score(orCollector, acceptDocs, min, max);
final DisiWrapper w = scorers[i];
assert w.doc < max;
scoreDisiWrapperIntoBitSet(w, acceptDocs, min, max);
}
docIdStreamView.base = base;
@ -251,20 +221,20 @@ final class BooleanScorer extends BulkScorer {
Arrays.fill(matching, 0L);
}
private BulkScorerAndDoc advance(int min) throws IOException {
private DisiWrapper advance(int min) throws IOException {
assert tail.size() == minShouldMatch - 1;
final HeadPriorityQueue head = this.head;
final TailPriorityQueue tail = this.tail;
BulkScorerAndDoc headTop = head.top();
BulkScorerAndDoc tailTop = tail.top();
while (headTop.next < min) {
DisiWrapper headTop = head.top();
DisiWrapper tailTop = tail.top();
while (headTop.doc < min) {
if (tailTop == null || headTop.cost <= tailTop.cost) {
headTop.advance(min);
headTop.doc = headTop.iterator.advance(min);
headTop = head.updateTop();
} else {
// swap the top of head and tail
final BulkScorerAndDoc previousHeadTop = headTop;
tailTop.advance(min);
final DisiWrapper previousHeadTop = headTop;
tailTop.doc = tailTop.iterator.advance(min);
headTop = head.updateTop(tailTop);
tailTop = tail.updateTop(previousHeadTop);
}
@ -282,9 +252,11 @@ final class BooleanScorer extends BulkScorer {
throws IOException {
while (maxFreq < minShouldMatch && maxFreq + tail.size() >= minShouldMatch) {
// a match is still possible
final BulkScorerAndDoc candidate = tail.pop();
candidate.advance(windowMin);
if (candidate.next < windowMax) {
final DisiWrapper candidate = tail.pop();
if (candidate.doc < windowMin) {
candidate.doc = candidate.iterator.advance(windowMin);
}
if (candidate.doc < windowMax) {
leads[maxFreq++] = candidate;
} else {
head.add(candidate);
@ -304,7 +276,7 @@ final class BooleanScorer extends BulkScorer {
// Push back scorers into head and tail
for (int i = 0; i < maxFreq; ++i) {
final BulkScorerAndDoc evicted = head.insertWithOverflow(leads[i]);
final DisiWrapper evicted = head.insertWithOverflow(leads[i]);
if (evicted != null) {
tail.add(evicted);
}
@ -312,7 +284,7 @@ final class BooleanScorer extends BulkScorer {
}
private void scoreWindowSingleScorer(
BulkScorerAndDoc bulkScorer,
DisiWrapper w,
LeafCollector collector,
Bits acceptDocs,
int windowMin,
@ -320,33 +292,44 @@ final class BooleanScorer extends BulkScorer {
int max)
throws IOException {
assert tail.size() == 0;
final int nextWindowBase = head.top().next & ~MASK;
final int nextWindowBase = head.top().doc & ~MASK;
final int end = Math.max(windowMax, Math.min(max, nextWindowBase));
bulkScorer.score(collector, acceptDocs, windowMin, end);
DocIdSetIterator it = w.iterator;
int doc = w.doc;
if (doc < windowMin) {
doc = it.advance(windowMin);
}
collector.setScorer(w.scorer);
for (; doc < end; doc = it.nextDoc()) {
if (acceptDocs == null || acceptDocs.get(doc)) {
collector.collect(doc);
}
}
w.doc = doc;
// reset the scorer that should be used for the general case
collector.setScorer(score);
}
private BulkScorerAndDoc scoreWindow(
BulkScorerAndDoc top, LeafCollector collector, Bits acceptDocs, int min, int max)
private DisiWrapper scoreWindow(
DisiWrapper top, LeafCollector collector, Bits acceptDocs, int min, int max)
throws IOException {
final int windowBase = top.next & ~MASK; // find the window that the next match belongs to
final int windowBase = top.doc & ~MASK; // find the window that the next match belongs to
final int windowMin = Math.max(min, windowBase);
final int windowMax = Math.min(max, windowBase + SIZE);
// Fill 'leads' with all scorers from 'head' that are in the right window
leads[0] = head.pop();
int maxFreq = 1;
while (head.size() > 0 && head.top().next < windowMax) {
while (head.size() > 0 && head.top().doc < windowMax) {
leads[maxFreq++] = head.pop();
}
if (minShouldMatch == 1 && maxFreq == 1) {
// special case: only one scorer can match in the current window,
// we can collect directly
final BulkScorerAndDoc bulkScorer = leads[0];
final DisiWrapper bulkScorer = leads[0];
scoreWindowSingleScorer(bulkScorer, collector, acceptDocs, windowMin, windowMax, max);
return head.add(bulkScorer);
} else {
@ -360,11 +343,11 @@ final class BooleanScorer extends BulkScorer {
public int score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException {
collector.setScorer(score);
BulkScorerAndDoc top = advance(min);
while (top.next < max) {
DisiWrapper top = advance(min);
while (top.doc < max) {
top = scoreWindow(top, collector, acceptDocs, min, max);
}
return top.next;
return top.doc;
}
}

View File

@ -289,9 +289,9 @@ final class BooleanScorerSupplier extends ScorerSupplier {
return new MaxScoreBulkScorer(maxDoc, optionalScorers);
}
List<BulkScorer> optional = new ArrayList<BulkScorer>();
List<Scorer> optional = new ArrayList<Scorer>();
for (ScorerSupplier ss : subs.get(Occur.SHOULD)) {
optional.add(ss.bulkScorer());
optional.add(ss.get(Long.MAX_VALUE));
}
return new BooleanScorer(optional, Math.max(1, minShouldMatch), scoreMode.needsScores());

View File

@ -153,70 +153,6 @@ final class BooleanWeight extends Weight {
return MatchesUtils.fromSubMatches(matches);
}
/**
 * Builds a {@link BulkScorer} over the SHOULD clauses only, or returns {@code null} when that is
 * not applicable. Package-private so that tests can force the use of BooleanScorer.
 */
BulkScorer optionalBulkScorer(LeafReaderContext context) throws IOException {
  if (scoreMode == ScoreMode.TOP_SCORES) {
    if (query.isPureDisjunction() == false) {
      return null;
    }
    // Keep a supplier for every SHOULD clause whose weight returns a non-null one.
    final List<ScorerSupplier> suppliers = new ArrayList<>();
    for (WeightedBooleanClause weighted : weightedClauses) {
      if (weighted.clause.occur() == Occur.SHOULD) {
        final ScorerSupplier supplier = weighted.weight.scorerSupplier(context);
        if (supplier != null) {
          suppliers.add(supplier);
        }
      }
    }
    if (suppliers.size() <= 1) {
      return null;
    }
    final List<Scorer> clauseScorers = new ArrayList<>();
    for (ScorerSupplier supplier : suppliers) {
      clauseScorers.add(supplier.get(Long.MAX_VALUE));
    }
    return new MaxScoreBulkScorer(context.reader().maxDoc(), clauseScorers);
  }

  // Any other score mode: combine the SHOULD clauses' own bulk scorers.
  final List<BulkScorer> bulkScorers = new ArrayList<BulkScorer>();
  for (WeightedBooleanClause weighted : weightedClauses) {
    if (weighted.clause.occur() == Occur.SHOULD) {
      final BulkScorer clauseScorer = weighted.weight.bulkScorer(context);
      if (clauseScorer != null) {
        bulkScorers.add(clauseScorer);
      }
    }
  }

  if (bulkScorers.isEmpty()) {
    return null;
  }
  if (query.getMinimumNumberShouldMatch() > bulkScorers.size()) {
    return null;
  }
  if (bulkScorers.size() == 1) {
    // A single clause needs no combining; hand back its bulk scorer directly.
    return bulkScorers.get(0);
  }
  return new BooleanScorer(
      bulkScorers, Math.max(1, query.getMinimumNumberShouldMatch()), scoreMode.needsScores());
}
@Override
public int count(LeafReaderContext context) throws IOException {
final int numDocs = context.reader().numDocs();

View File

@ -16,7 +16,6 @@
*/
package org.apache.lucene.search;
import com.carrotsearch.randomizedtesting.generators.RandomNumbers;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
@ -33,8 +32,9 @@ import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.search.QueryUtils;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.IntArrayDocIdSet;
public class TestBooleanOr extends LuceneTestCase {
@ -205,34 +205,30 @@ public class TestBooleanOr extends LuceneTestCase {
dir.close();
}
private static BulkScorer scorer(int... matches) {
return new BulkScorer() {
final Score scorer = new Score();
int i = 0;
private static Scorer scorer(int... matches) throws IOException {
matches = ArrayUtil.growExact(matches, matches.length + 1);
matches[matches.length - 1] = DocIdSetIterator.NO_MORE_DOCS;
DocIdSetIterator it = new IntArrayDocIdSet(matches, matches.length - 1).iterator();
return new Scorer() {
@Override
public int score(LeafCollector collector, Bits acceptDocs, int min, int max)
throws IOException {
collector.setScorer(scorer);
while (i < matches.length && matches[i] < min) {
i += 1;
}
while (i < matches.length && matches[i] < max) {
int doc = matches[i];
if (acceptDocs == null || acceptDocs.get(doc)) {
collector.collect(doc);
}
i += 1;
}
if (i == matches.length) {
return DocIdSetIterator.NO_MORE_DOCS;
}
return RandomNumbers.randomIntBetween(random(), max, matches[i]);
public DocIdSetIterator iterator() {
return it;
}
@Override
public long cost() {
return matches.length;
public int docID() {
return it.docID();
}
@Override
public float getMaxScore(int upTo) throws IOException {
return Float.MAX_VALUE;
}
@Override
public float score() throws IOException {
return 0;
}
};
}
@ -240,7 +236,7 @@ public class TestBooleanOr extends LuceneTestCase {
// Make sure that BooleanScorer keeps working even if the sub clauses return
// next matching docs which are less than the actual next match
public void testSubScorerNextIsNotMatch() throws IOException {
final List<BulkScorer> optionalScorers =
final List<Scorer> optionalScorers =
Arrays.asList(
scorer(100000, 1000001, 9999999),
scorer(4000, 1000051),

View File

@ -128,7 +128,8 @@ public class TestMinShouldMatch2 extends LuceneTestCase {
case SCORER:
return weight.scorer(reader.getContext());
case BULK_SCORER:
final BulkScorer bulkScorer = weight.optionalBulkScorer(reader.getContext());
final ScorerSupplier ss = weight.scorerSupplier(reader.getContext());
final BulkScorer bulkScorer = ss.bulkScorer();
if (bulkScorer == null) {
if (weight.scorer(reader.getContext()) != null) {
throw new AssertionError("BooleanScorer should be applicable for this query");