Make WANDScorer compute scores on the fly. (#14021)

Currently, `WANDScorer` considers that a hit is a match if the sum of maximum
scores across clauses is greater than or equal to the minimum competitive score.
We can do better by computing scores of leading clauses on the fly. This helps
because scores are often lower than the score upper bound, so using actual
scores instead of score upper bounds can help skip advancing more clauses.

For reference, we are already doing the same trick in our conjunction (bulk)
scorers and in `MaxScoreBulkScorer` (bulk scorer for top-level disjunctions).
This commit is contained in:
Adrien Grand 2024-11-26 16:00:30 +01:00
parent 71715b59e8
commit 3f620dcd34
2 changed files with 59 additions and 18 deletions

View File

@ -81,6 +81,9 @@ Optimizations
* GITHUB#13989: Faster checksum computation. (Jean-François Boeuf)
* GITHUB#14021: WANDScorer now computes scores on the fly, which helps prevent
advancing "tail" clauses in many cases. (Adrien Grand)
Bug Fixes
---------------------
* GITHUB#13832: Fixed an issue where the DefaultPassageFormatter.format method did not format passages as intended

View File

@ -24,6 +24,7 @@ import static org.apache.lucene.search.ScorerUtil.costWithMinShouldMatch;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import org.apache.lucene.util.MathUtil;
@ -129,7 +130,7 @@ final class WANDScorer extends Scorer {
// some descriptions of WAND (Weak AND).
DisiWrapper lead;
int doc; // current doc ID of the leads
long leadMaxScore; // sum of the max scores of scorers in 'lead'
double leadScore; // score of the leads
// priority queue of scorers that are too advanced compared to the current
// doc. Ordered by doc ID.
@ -195,7 +196,7 @@ final class WANDScorer extends Scorer {
}
for (Scorer scorer : scorers) {
addLead(new DisiWrapper(scorer));
addUnpositionedLead(new DisiWrapper(scorer));
}
this.cost =
@ -208,7 +209,7 @@ final class WANDScorer extends Scorer {
// returns a boolean so that it can be called from assert
// the return value is useless: it always returns true
private boolean ensureConsistent() {
private boolean ensureConsistent() throws IOException {
if (scoreMode == ScoreMode.TOP_SCORES) {
long maxScoreSum = 0;
for (int i = 0; i < tailSize; ++i) {
@ -217,12 +218,19 @@ final class WANDScorer extends Scorer {
}
assert maxScoreSum == tailMaxScore : maxScoreSum + " " + tailMaxScore;
maxScoreSum = 0;
List<Float> leadScores = new ArrayList<>();
for (DisiWrapper w = lead; w != null; w = w.next) {
assert w.doc == doc;
maxScoreSum = Math.addExact(maxScoreSum, w.scaledMaxScore);
leadScores.add(w.scorer.score());
}
assert maxScoreSum == leadMaxScore : maxScoreSum + " " + leadMaxScore;
// Make sure to recompute the sum in the same order to get the same floating point rounding
// errors.
Collections.reverse(leadScores);
double recomputedLeadScore = 0;
for (float score : leadScores) {
recomputedLeadScore += score;
}
assert recomputedLeadScore == leadScore;
assert minCompetitiveScore == 0
|| tailMaxScore < minCompetitiveScore
@ -285,8 +293,6 @@ final class WANDScorer extends Scorer {
@Override
public int advance(int target) throws IOException {
assert ensureConsistent();
// Move 'lead' iterators back to the tail
pushBackLeads(target);
@ -319,17 +325,34 @@ final class WANDScorer extends Scorer {
assert lead == null;
moveToNextCandidate();
while (leadMaxScore < minCompetitiveScore || freq < minShouldMatch) {
if (leadMaxScore + tailMaxScore < minCompetitiveScore
long scaledLeadScore = 0;
if (scoreMode == ScoreMode.TOP_SCORES) {
scaledLeadScore =
scaleMaxScore(
(float) MathUtil.sumUpperBound(leadScore, FLOAT_MANTISSA_BITS), scalingFactor);
}
while (scaledLeadScore < minCompetitiveScore || freq < minShouldMatch) {
assert ensureConsistent();
if (scaledLeadScore + tailMaxScore < minCompetitiveScore
|| freq + tailSize < minShouldMatch) {
return false;
} else {
// a match on doc is still possible, try to
// advance scorers from the tail
DisiWrapper prevLead = lead;
advanceTail();
if (scoreMode == ScoreMode.TOP_SCORES && lead != prevLead) {
assert prevLead == lead.next;
scaledLeadScore =
scaleMaxScore(
(float) MathUtil.sumUpperBound(leadScore, FLOAT_MANTISSA_BITS),
scalingFactor);
}
}
}
assert ensureConsistent();
return true;
}
@ -342,10 +365,20 @@ final class WANDScorer extends Scorer {
}
/** Add a disi to the linked list of leads. */
private void addLead(DisiWrapper lead) {
private void addLead(DisiWrapper lead) throws IOException {
lead.next = this.lead;
this.lead = lead;
freq += 1;
if (scoreMode == ScoreMode.TOP_SCORES) {
leadScore += lead.scorer.score();
}
}
/** Add a disi that is not positioned yet (doc ID is still -1) to the linked list of leads. */
private void addUnpositionedLead(DisiWrapper lead) {
assert lead.doc == -1;
lead.next = this.lead;
this.lead = lead;
leadMaxScore += lead.scaledMaxScore;
freq += 1;
}
@ -359,7 +392,6 @@ final class WANDScorer extends Scorer {
}
}
lead = null;
leadMaxScore = 0;
}
/** Make sure all disis in 'head' are on or after 'target'. */
@ -488,8 +520,10 @@ final class WANDScorer extends Scorer {
lead = head.pop();
assert doc == lead.doc;
lead.next = null;
leadMaxScore = lead.scaledMaxScore;
freq = 1;
if (scoreMode == ScoreMode.TOP_SCORES) {
leadScore = lead.scorer.score();
}
while (head.size() > 0 && head.top().doc == doc) {
addLead(head.pop());
}
@ -514,11 +548,15 @@ final class WANDScorer extends Scorer {
public float score() throws IOException {
// we need to know about all matches
advanceAllTail();
double score = 0;
for (DisiWrapper s = lead; s != null; s = s.next) {
score += s.scorer.score();
double leadScore = this.leadScore;
if (scoreMode != ScoreMode.TOP_SCORES) {
// With TOP_SCORES, the score was already computed on the fly.
for (DisiWrapper s = lead; s != null; s = s.next) {
leadScore += s.scorer.score();
}
}
return (float) score;
return (float) leadScore;
}
@Override