Fix `DefaultBulkScorer` to not advance the competitive iterator beyond the end of the window. (#12481)

The way `DefaultBulkScorer` uses `ConjunctionDISI` may make it advance the
competitive iterator beyond the end of the window. This may cause bugs with
bulk scorers such as `BooleanScorer` that sometimes delegate to the single
clause that has matches in a given window of doc IDs. We should therefore make
sure not to advance the competitive iterator beyond the end of the window based
on this clause, as other clauses may have matches as well.
This commit is contained in:
Adrien Grand 2023-08-03 07:19:27 +02:00 committed by GitHub
parent acffcfaaf0
commit df3632cb03
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 87 additions and 78 deletions

View File

@ -17,7 +17,6 @@
package org.apache.lucene.search;
import java.io.IOException;
import java.util.Arrays;
import org.apache.lucene.index.IndexReaderContext;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
@ -228,32 +227,16 @@ public abstract class Weight implements SegmentCacheable {
collector.setScorer(scorer);
DocIdSetIterator scorerIterator = twoPhase == null ? iterator : twoPhase.approximation();
DocIdSetIterator competitiveIterator = collector.competitiveIterator();
DocIdSetIterator filteredIterator;
if (competitiveIterator == null) {
filteredIterator = scorerIterator;
} else {
// Wrap CompetitiveIterator and ScorerIterator start with (i.e., calling nextDoc()) the last
// visited docID because ConjunctionDISI might have advanced to it in the previous
// scoreRange, but we didn't process due to the range limit of scoreRange.
if (scorerIterator.docID() != -1) {
scorerIterator = new StartDISIWrapper(scorerIterator);
}
if (competitiveIterator.docID() != -1) {
competitiveIterator = new StartDISIWrapper(competitiveIterator);
}
// filter scorerIterator to keep only competitive docs as defined by collector
filteredIterator =
ConjunctionUtils.intersectIterators(Arrays.asList(scorerIterator, competitiveIterator));
}
if (filteredIterator.docID() == -1 && min == 0 && max == DocIdSetIterator.NO_MORE_DOCS) {
scoreAll(collector, filteredIterator, twoPhase, acceptDocs);
if (competitiveIterator == null
&& scorerIterator.docID() == -1
&& min == 0
&& max == DocIdSetIterator.NO_MORE_DOCS) {
scoreAll(collector, scorerIterator, twoPhase, acceptDocs);
return DocIdSetIterator.NO_MORE_DOCS;
} else {
int doc = filteredIterator.docID();
if (doc < min) {
doc = filteredIterator.advance(min);
}
return scoreRange(collector, filteredIterator, twoPhase, acceptDocs, doc, max);
return scoreRange(
collector, scorerIterator, twoPhase, competitiveIterator, acceptDocs, min, max);
}
}
@ -266,27 +249,59 @@ public abstract class Weight implements SegmentCacheable {
LeafCollector collector,
DocIdSetIterator iterator,
TwoPhaseIterator twoPhase,
DocIdSetIterator competitiveIterator,
Bits acceptDocs,
int currentDoc,
int end)
int min,
int max)
throws IOException {
if (twoPhase == null) {
while (currentDoc < end) {
if (acceptDocs == null || acceptDocs.get(currentDoc)) {
collector.collect(currentDoc);
}
currentDoc = iterator.nextDoc();
if (competitiveIterator != null) {
if (competitiveIterator.docID() > min) {
min = competitiveIterator.docID();
// The competitive iterator may not match any docs in the range.
min = Math.min(min, max);
}
return currentDoc;
} else {
while (currentDoc < end) {
if ((acceptDocs == null || acceptDocs.get(currentDoc)) && twoPhase.matches()) {
collector.collect(currentDoc);
}
currentDoc = iterator.nextDoc();
}
return currentDoc;
}
int doc = iterator.docID();
if (doc < min) {
if (doc == min - 1) {
doc = iterator.nextDoc();
} else {
doc = iterator.advance(min);
}
}
if (twoPhase == null && competitiveIterator == null) {
// Optimize simple iterators with collectors that can't skip
while (doc < max) {
if (acceptDocs == null || acceptDocs.get(doc)) {
collector.collect(doc);
}
doc = iterator.nextDoc();
}
} else {
while (doc < max) {
if (competitiveIterator != null) {
assert competitiveIterator.docID() <= doc;
if (competitiveIterator.docID() < doc) {
competitiveIterator.advance(doc);
}
if (competitiveIterator.docID() != doc) {
doc = iterator.advance(competitiveIterator.docID());
continue;
}
}
if ((acceptDocs == null || acceptDocs.get(doc))
&& (twoPhase == null || twoPhase.matches())) {
collector.collect(doc);
}
doc = iterator.nextDoc();
}
}
return doc;
}
/**
@ -320,39 +335,4 @@ public abstract class Weight implements SegmentCacheable {
}
}
}
/**
 * Presents an already-positioned {@link DocIdSetIterator} as one that has not started yet: the
 * wrapper reports {@code -1} until it is first moved, and any target at or before the doc ID the
 * wrapped iterator was on at construction time lands on that doc ID.
 */
private static class StartDISIWrapper extends DocIdSetIterator {
  /** Wrapped iterator; consulted only for targets beyond {@link #resumeDocID}. */
  private final DocIdSetIterator delegate;

  /** Doc ID the wrapped iterator was positioned on when this wrapper was created. */
  private final int resumeDocID;

  /** Doc ID exposed by this wrapper; stays {@code -1} until the first move. */
  private int current = -1;

  StartDISIWrapper(DocIdSetIterator in) {
    delegate = in;
    resumeDocID = in.docID();
  }

  @Override
  public int docID() {
    return current;
  }

  @Override
  public int nextDoc() throws IOException {
    // Stepping is expressed as advancing to the doc right after the current one.
    return advance(current + 1);
  }

  @Override
  public int advance(int target) throws IOException {
    if (target > resumeDocID) {
      // Past the captured position: the wrapped iterator does the real work.
      current = delegate.advance(target);
    } else {
      // At or before the captured position: surface that doc without touching the delegate.
      current = resumeDocID;
    }
    return current;
  }

  @Override
  public long cost() {
    return delegate.cost();
  }
}
}

View File

@ -53,7 +53,36 @@ class AssertingLeafCollector extends FilterLeafCollector {
/**
 * Exposes the competitive iterator of the wrapped collector {@code in}.
 *
 * <p>NOTE(review): this span is a rendered diff hunk whose +/- markers were stripped. The first
 * {@code return} below looks like the removed pre-change body and everything after it like the
 * added replacement; read literally as Java, the statements after that first return would be
 * unreachable. Confirm against the real file before editing.
 *
 * <p>In the longer body, a {@code null} iterator from the wrapped collector is passed through as
 * {@code null}; otherwise the iterator is wrapped so that moving it is checked against
 * {@code max}. {@code max} is declared outside this view; the assertion messages describe it as
 * the end of the scored window (TODO confirm where it is set).
 */
@Override
public DocIdSetIterator competitiveIterator() throws IOException {
// Apparent pre-change body: hand the wrapped collector's iterator back unchecked.
return in.competitiveIterator();
// Apparent post-change body starts here. The local `in` shadows the collector field `in`.
final DocIdSetIterator in = this.in.competitiveIterator();
if (in == null) {
// The wrapped collector has no competitive iterator, so there is nothing to wrap.
return null;
}
// Delegating iterator: docID() and cost() pass straight through, while nextDoc() and
// advance() assert against `max` before delegating.
return new DocIdSetIterator() {
@Override
public int nextDoc() throws IOException {
// The current position must still be below `max` before stepping forward.
assert in.docID() < max
: "advancing beyond the end of the scored window: docID=" + in.docID() + ", max=" + max;
return in.nextDoc();
}
@Override
public int docID() {
return in.docID();
}
@Override
public long cost() {
return in.cost();
}
@Override
public int advance(int target) throws IOException {
// A target equal to `max` is accepted; only targets strictly past it trip the assertion.
assert target <= max
: "advancing beyond the end of the scored window: target=" + target + ", max=" + max;
return in.advance(target);
}
};
}
@Override