Add BS1 optimization to MaxScoreBulkScorer. (#12444)

Lucene's scorers that can dynamically prune on score provide great speedups
when they manage to skip many hits. Unfortunately, there are also cases when
they cannot skip hits efficiently, one example case being when there are many
clauses in the query. In this case, exhaustively evaluating the set of matches
with `BooleanScorer` (BS1) may perform several times faster.

This commit adds to `MaxScoreBulkScorer` the BS1 optimization that consists of
collecting hits into a bitset to save the overhead of reordering priority
queues. This helps make performance degrade much more gracefully when dynamic
pruning cannot help much.

Closes #12439
This commit is contained in:
Adrien Grand 2023-07-19 07:51:22 -04:00 committed by GitHub
parent 55f2f9958b
commit 17c13a76c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 170 additions and 63 deletions

View File

@ -147,6 +147,10 @@ Optimizations
* GITHUB#12361: Faster top-level disjunctions sorted by descending score.
(Adrien Grand)
* GITHUB#12444: Faster top-level disjunctions sorted by descending score in
case of many terms or queries that expose suboptimal score upper bounds.
(Adrien Grand)
* GITHUB#12383: Assign a dummy simScorer in TermsWeight if score is not needed. (Sagar Upadhyaya)
* GITHUB#12372: Reduce allocation during HNSW construction (Jonathan Ellis)

View File

@ -191,7 +191,7 @@ final class BooleanWeight extends Weight {
// pkg-private for forcing use of BooleanScorer in tests
BulkScorer optionalBulkScorer(LeafReaderContext context) throws IOException {
if (scoreMode == ScoreMode.TOP_SCORES) {
if (!query.isPureDisjunction() || weightedClauses.size() > 2) {
if (!query.isPureDisjunction()) {
return null;
}

View File

@ -57,6 +57,23 @@ public final class DisiPriorityQueue implements Iterable<DisiWrapper> {
return heap[0];
}
/** Return the 2nd least value in this heap, or null if the heap contains less than 2 values. */
public DisiWrapper top2() {
switch (size()) {
case 0:
case 1:
return null;
case 2:
return heap[1];
default:
if (heap[1].doc <= heap[2].doc) {
return heap[1];
} else {
return heap[2];
}
}
}
/** Get the list of scorers which are on the current doc. */
public DisiWrapper topList() {
final DisiWrapper[] heap = this.heap;

View File

@ -21,9 +21,12 @@ import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
final class MaxScoreBulkScorer extends BulkScorer {
static final int INNER_WINDOW_SIZE = 1 << 11;
private final int maxDoc;
// All scorers, sorted by increasing max score.
private final DisiWrapper[] allScorers;
@ -40,6 +43,9 @@ final class MaxScoreBulkScorer extends BulkScorer {
private Score scorable = new Score();
private final double[] maxScoreSums;
private final long[] windowMatches = new long[FixedBitSet.bits2words(INNER_WINDOW_SIZE)];
private final double[] windowScores = new double[INNER_WINDOW_SIZE];
MaxScoreBulkScorer(int maxDoc, List<Scorer> scorers) throws IOException {
this.maxDoc = maxDoc;
allScorers = new DisiWrapper[scorers.size()];
@ -60,75 +66,121 @@ final class MaxScoreBulkScorer extends BulkScorer {
public int score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException {
collector.setScorer(scorable);
int windowMin = min;
main:
while (windowMin < max) {
int windowMax = updateMaxWindowScores(windowMin);
windowMax = Math.min(windowMax, max);
// This scorer computes outer windows based on impacts that are stored in the index. These outer
// windows should be small enough to provide good upper bounds of scores, and big enough to make
// sure we spend more time collecting docs than recomputing windows.
// Then within these outer windows, it creates inner windows of size WINDOW_SIZE that help
// collect matches into a bitset and save the overhead of rebalancing the priority queue on
// every match.
int outerWindowMin = min;
outer:
while (outerWindowMin < max) {
int outerWindowMax = updateMaxWindowScores(outerWindowMin);
outerWindowMax = Math.min(outerWindowMax, max);
if (partitionScorers() == false) {
// No matches in this window
windowMin = windowMax;
outerWindowMin = outerWindowMax;
continue;
}
DisiWrapper top = essentialQueue.top();
while (top.doc < windowMin) {
top.doc = top.iterator.advance(windowMin);
while (top.doc < outerWindowMin) {
top.doc = top.iterator.advance(outerWindowMin);
top = essentialQueue.updateTop();
}
while (top.doc < windowMax) {
if (acceptDocs == null || acceptDocs.get(top.doc)) {
DisiWrapper topList = essentialQueue.topList();
double score = topList.scorer.score();
for (DisiWrapper w = topList.next; w != null; w = w.next) {
score += w.scorer.score();
}
boolean possibleMatch = true;
for (int i = firstEssentialScorer - 1; i >= 0; --i) {
float maxPossibleScore = maxScorePropagator.scoreSumUpperBound(score + maxScoreSums[i]);
if (maxPossibleScore < minCompetitiveScore) {
possibleMatch = false;
break;
}
DisiWrapper scorer = allScorers[i];
if (scorer.doc < top.doc) {
scorer.doc = scorer.iterator.advance(top.doc);
}
if (scorer.doc == top.doc) {
score += scorer.scorer.score();
}
}
if (possibleMatch) {
scorable.score = (float) score;
collector.collect(top.doc);
}
}
int doc = top.doc;
do {
top.doc = top.iterator.nextDoc();
top = essentialQueue.updateTop();
} while (top.doc == doc);
while (top.doc < outerWindowMax) {
scoreInnerWindow(collector, acceptDocs, outerWindowMax);
if (minCompetitiveScoreUpdated) {
minCompetitiveScoreUpdated = false;
if (partitionScorers()) {
top = essentialQueue.top();
} else {
windowMin = windowMax;
continue main;
if (partitionScorers() == false) {
outerWindowMin = outerWindowMax;
continue outer;
}
}
top = essentialQueue.top();
}
windowMin = windowMax;
outerWindowMin = outerWindowMax;
}
return nextCandidate(max);
}
private void scoreInnerWindow(LeafCollector collector, Bits acceptDocs, int max)
throws IOException {
DisiWrapper top = essentialQueue.top();
DisiWrapper top2 = essentialQueue.top2();
if (top2 == null) {
scoreInnerWindowSingleEssentialClause(collector, acceptDocs, max);
} else if (top2.doc - INNER_WINDOW_SIZE / 2 >= top.doc) {
// The first half of the window would match a single clause. Let's collect this single clause
// until the next doc ID of the next clause.
scoreInnerWindowSingleEssentialClause(collector, acceptDocs, Math.min(max, top2.doc));
} else {
scoreInnerWindowMultipleEssentialClauses(collector, acceptDocs, max);
}
}
private void scoreInnerWindowSingleEssentialClause(
LeafCollector collector, Bits acceptDocs, int upTo) throws IOException {
DisiWrapper top = essentialQueue.top();
// single essential clause in this window, we can iterate it directly and skip the bitset.
// this is a common case for 2-clauses queries
for (int doc = top.doc; doc < upTo; doc = top.iterator.nextDoc()) {
if (acceptDocs != null && acceptDocs.get(doc) == false) {
continue;
}
scoreNonEssentialClauses(collector, doc, top.scorer.score());
if (minCompetitiveScoreUpdated) {
// force scorers to be partitioned again before collecting more hits
top.iterator.nextDoc();
break;
}
}
top.doc = top.iterator.docID();
essentialQueue.updateTop();
}
private void scoreInnerWindowMultipleEssentialClauses(
LeafCollector collector, Bits acceptDocs, int max) throws IOException {
DisiWrapper top = essentialQueue.top();
int innerWindowMin = top.doc;
int innerWindowMax = (int) Math.min(max, (long) innerWindowMin + INNER_WINDOW_SIZE);
// Collect matches of essential clauses into a bitset
do {
for (int doc = top.doc; doc < innerWindowMax; doc = top.iterator.nextDoc()) {
if (acceptDocs == null || acceptDocs.get(doc)) {
final int i = doc - innerWindowMin;
windowMatches[i >>> 6] |= 1L << i;
windowScores[i] += top.scorer.score();
}
}
top.doc = top.iterator.docID();
top = essentialQueue.updateTop();
} while (top.doc < innerWindowMax);
for (int wordIndex = 0; wordIndex < windowMatches.length; ++wordIndex) {
long bits = windowMatches[wordIndex];
windowMatches[wordIndex] = 0L;
while (bits != 0L) {
int ntz = Long.numberOfTrailingZeros(bits);
bits ^= 1L << ntz;
int index = wordIndex << 6 | ntz;
int doc = innerWindowMin + index;
double score = windowScores[index];
windowScores[index] = 0d;
scoreNonEssentialClauses(collector, doc, score);
}
}
}
private int updateMaxWindowScores(int windowMin) throws IOException {
// Only use essential scorers to compute the window's max doc ID, in order to avoid constantly
// recomputing max scores over small windows
@ -145,6 +197,11 @@ final class MaxScoreBulkScorer extends BulkScorer {
final int upTo = scorer.scorer.advanceShallow(Math.max(scorer.doc, windowMin));
windowMax = (int) Math.min(windowMax, upTo + 1L); // upTo is inclusive
}
// Score at least an entire inner window of docs
windowMax =
Math.max(
windowMax, (int) Math.min(Integer.MAX_VALUE, (long) windowMin + INNER_WINDOW_SIZE));
for (DisiWrapper scorer : allScorers) {
if (scorer.doc < windowMax) {
scorer.maxWindowScore = scorer.scorer.getMaxScore(windowMax - 1);
@ -155,6 +212,32 @@ final class MaxScoreBulkScorer extends BulkScorer {
return windowMax;
}
private void scoreNonEssentialClauses(LeafCollector collector, int doc, double essentialScore)
throws IOException {
double score = essentialScore;
for (int i = firstEssentialScorer - 1; i >= 0; --i) {
float maxPossibleScore = maxScorePropagator.scoreSumUpperBound(score + maxScoreSums[i]);
if (maxPossibleScore < minCompetitiveScore) {
// Hit is not competitive.
return;
} else if (maxScoreSums[i] == 0f) {
// Can break since scorers are sorted by ascending score.
break;
}
DisiWrapper scorer = allScorers[i];
if (scorer.doc < doc) {
scorer.doc = scorer.iterator.advance(doc);
}
if (scorer.doc == doc) {
score += scorer.scorer.score();
}
}
scorable.score = (float) score;
collector.collect(doc);
}
private boolean partitionScorers() {
Arrays.sort(allScorers, Comparator.comparingDouble(scorer -> scorer.maxWindowScore));
double maxScoreSum = 0;

View File

@ -49,6 +49,9 @@ public class TestMaxScoreBulkScorer extends LuceneTestCase {
doc.add(new StringField("foo", value, Field.Store.NO));
}
w.addDocument(doc);
for (int i = 1; i < MaxScoreBulkScorer.INNER_WINDOW_SIZE; ++i) {
w.addDocument(new Document());
}
}
w.forceMerge(1);
}
@ -95,19 +98,19 @@ public class TestMaxScoreBulkScorer extends LuceneTestCase {
assertEquals(2 + 1, scorer.score(), 0);
break;
case 1:
assertEquals(1, doc);
assertEquals(2048, doc);
assertEquals(2, scorer.score(), 0);
break;
case 2:
assertEquals(3, doc);
assertEquals(6144, doc);
assertEquals(2 + 1, scorer.score(), 0);
break;
case 3:
assertEquals(4, doc);
assertEquals(8192, doc);
assertEquals(1, scorer.score(), 0);
break;
case 4:
assertEquals(5, doc);
assertEquals(10240, doc);
assertEquals(1, scorer.score(), 0);
break;
default:
@ -162,13 +165,13 @@ public class TestMaxScoreBulkScorer extends LuceneTestCase {
assertEquals(2 + 1, scorer.score(), 0);
break;
case 1:
assertEquals(1, doc);
assertEquals(2048, doc);
assertEquals(2, scorer.score(), 0);
// simulate top-2 retrieval
scorer.setMinCompetitiveScore(Math.nextUp(2));
break;
case 2:
assertEquals(3, doc);
assertEquals(6144, doc);
assertEquals(2 + 1, scorer.score(), 0);
scorer.setMinCompetitiveScore(Math.nextUp(2 + 1));
break;
@ -227,19 +230,19 @@ public class TestMaxScoreBulkScorer extends LuceneTestCase {
assertEquals(2 + 1, scorer.score(), 0);
break;
case 1:
assertEquals(1, doc);
assertEquals(2048, doc);
assertEquals(2, scorer.score(), 0);
break;
case 2:
assertEquals(3, doc);
assertEquals(6144, doc);
assertEquals(2 + 1 + 3, scorer.score(), 0);
break;
case 3:
assertEquals(4, doc);
assertEquals(8192, doc);
assertEquals(1, scorer.score(), 0);
break;
case 4:
assertEquals(5, doc);
assertEquals(10240, doc);
assertEquals(1 + 3, scorer.score(), 0);
break;
default:
@ -297,18 +300,18 @@ public class TestMaxScoreBulkScorer extends LuceneTestCase {
assertEquals(2 + 1, scorer.score(), 0);
break;
case 1:
assertEquals(1, doc);
assertEquals(2048, doc);
assertEquals(2, scorer.score(), 0);
// simulate top-2 retrieval
scorer.setMinCompetitiveScore(Math.nextUp(2));
break;
case 2:
assertEquals(3, doc);
assertEquals(6144, doc);
assertEquals(2 + 1 + 3, scorer.score(), 0);
scorer.setMinCompetitiveScore(Math.nextUp(2 + 1));
break;
case 3:
assertEquals(5, doc);
assertEquals(10240, doc);
assertEquals(1 + 3, scorer.score(), 0);
scorer.setMinCompetitiveScore(Math.nextUp(1 + 3));
break;