Run filtered disjunctions with MaxScoreBulkScorer. (#14014)

Running filtered disjunctions with a specialized bulk scorer seems to yield a
good speedup. For what it's worth, I also tried to implement a MAXSCORE-based
scorer to see if it had to do with the `BulkScorer` specialization or the
algorithm, but it didn't help.

To work properly, I had to add a rewrite rule to inline disjunctions in a MUST
clause.

As a next step, it would be interesting to see if we can further optimize this
by loading the filter into a bitset and applying it like live docs.
This commit is contained in:
Adrien Grand 2024-11-27 21:56:03 +01:00 committed by GitHub
parent d9aa525c9e
commit 98c59a710e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 304 additions and 14 deletions

View File

@ -110,6 +110,9 @@ Optimizations
* GITHUB#14021: WANDScorer now computes scores on the fly, which helps prevent
advancing "tail" clauses in many cases. (Adrien Grand)
* GITHUB#14014: Filtered disjunctions now get executed via `MaxScoreBulkScorer`.
(Adrien Grand)
Bug Fixes
---------------------
* GITHUB#13832: Fixed an issue where the DefaultPassageFormatter.format method did not format passages as intended

View File

@ -624,6 +624,26 @@ public class BooleanQuery extends Query implements Iterable<BooleanClause> {
}
}
// Inline SHOULD clauses from the only MUST clause
{
if (clauseSets.get(Occur.SHOULD).isEmpty()
&& clauseSets.get(Occur.MUST).size() == 1
&& clauseSets.get(Occur.MUST).iterator().next() instanceof BooleanQuery inner
&& inner.clauses.size() == inner.clauseSets.get(Occur.SHOULD).size()) {
BooleanQuery.Builder rewritten = new BooleanQuery.Builder();
for (BooleanClause clause : clauses) {
if (clause.occur() != Occur.MUST) {
rewritten.add(clause);
}
}
for (BooleanClause innerClause : inner.clauses()) {
rewritten.add(innerClause);
}
rewritten.setMinimumNumberShouldMatch(Math.max(1, inner.getMinimumNumberShouldMatch()));
return rewritten.build();
}
}
return super.rewrite(indexSearcher);
}

View File

@ -183,7 +183,8 @@ final class BooleanScorerSupplier extends ScorerSupplier {
BulkScorer booleanScorer() throws IOException {
final int numOptionalClauses = subs.get(Occur.SHOULD).size();
final int numRequiredClauses = subs.get(Occur.MUST).size() + subs.get(Occur.FILTER).size();
final int numMustClauses = subs.get(Occur.MUST).size();
final int numRequiredClauses = numMustClauses + subs.get(Occur.FILTER).size();
BulkScorer positiveScorer;
if (numRequiredClauses == 0) {
@ -209,6 +210,8 @@ final class BooleanScorerSupplier extends ScorerSupplier {
}
positiveScorer = optionalBulkScorer();
} else if (numMustClauses == 0 && numOptionalClauses > 1 && minShouldMatch >= 1) {
positiveScorer = filteredOptionalBulkScorer();
} else if (numRequiredClauses > 0 && numOptionalClauses == 0 && minShouldMatch == 0) {
positiveScorer = requiredBulkScorer();
} else {
@ -286,7 +289,7 @@ final class BooleanScorerSupplier extends ScorerSupplier {
optionalScorers.add(ss.get(Long.MAX_VALUE));
}
return new MaxScoreBulkScorer(maxDoc, optionalScorers);
return new MaxScoreBulkScorer(maxDoc, optionalScorers, null);
}
List<Scorer> optional = new ArrayList<Scorer>();
@ -297,6 +300,32 @@ final class BooleanScorerSupplier extends ScorerSupplier {
return new BooleanScorer(optional, Math.max(1, minShouldMatch), scoreMode.needsScores());
}
BulkScorer filteredOptionalBulkScorer() throws IOException {
if (subs.get(Occur.MUST).isEmpty() == false
|| subs.get(Occur.FILTER).isEmpty()
|| scoreMode != ScoreMode.TOP_SCORES
|| subs.get(Occur.SHOULD).size() <= 1
|| minShouldMatch > 1) {
return null;
}
long cost = cost();
List<Scorer> optionalScorers = new ArrayList<>();
for (ScorerSupplier ss : subs.get(Occur.SHOULD)) {
optionalScorers.add(ss.get(cost));
}
List<Scorer> filters = new ArrayList<>();
for (ScorerSupplier ss : subs.get(Occur.FILTER)) {
filters.add(ss.get(cost));
}
Scorer filterScorer;
if (filters.size() == 1) {
filterScorer = filters.iterator().next();
} else {
filterScorer = new ConjunctionScorer(filters, Collections.emptySet());
}
return new MaxScoreBulkScorer(maxDoc, optionalScorers, filterScorer);
}
// Return a BulkScorer for the required clauses only
private BulkScorer requiredBulkScorer() throws IOException {
if (subs.get(Occur.MUST).size() + subs.get(Occur.FILTER).size() == 0) {

View File

@ -46,12 +46,14 @@ final class MaxScoreBulkScorer extends BulkScorer {
float minCompetitiveScore;
private final Score scorable = new Score();
final double[] maxScoreSums;
private final DisiWrapper filter;
private final long[] windowMatches = new long[FixedBitSet.bits2words(INNER_WINDOW_SIZE)];
private final double[] windowScores = new double[INNER_WINDOW_SIZE];
MaxScoreBulkScorer(int maxDoc, List<Scorer> scorers) throws IOException {
MaxScoreBulkScorer(int maxDoc, List<Scorer> scorers, Scorer filter) throws IOException {
this.maxDoc = maxDoc;
this.filter = filter == null ? null : new DisiWrapper(filter);
allScorers = new DisiWrapper[scorers.size()];
scratch = new DisiWrapper[allScorers.length];
int i = 0;
@ -123,7 +125,7 @@ final class MaxScoreBulkScorer extends BulkScorer {
}
while (top.doc < outerWindowMax) {
scoreInnerWindow(collector, acceptDocs, outerWindowMax);
scoreInnerWindow(collector, acceptDocs, outerWindowMax, filter);
top = essentialQueue.top();
if (minCompetitiveScore >= nextMinCompetitiveScore) {
// The minimum competitive score increased substantially, so we can now partition scorers
@ -139,9 +141,11 @@ final class MaxScoreBulkScorer extends BulkScorer {
return nextCandidate(max);
}
private void scoreInnerWindow(LeafCollector collector, Bits acceptDocs, int max)
throws IOException {
if (allScorers.length - firstRequiredScorer >= 2) {
private void scoreInnerWindow(
LeafCollector collector, Bits acceptDocs, int max, DisiWrapper filter) throws IOException {
if (filter != null) {
scoreInnerWindowWithFilter(collector, acceptDocs, max, filter);
} else if (allScorers.length - firstRequiredScorer >= 2) {
scoreInnerWindowAsConjunction(collector, acceptDocs, max);
} else {
DisiWrapper top = essentialQueue.top();
@ -158,6 +162,55 @@ final class MaxScoreBulkScorer extends BulkScorer {
}
}
private void scoreInnerWindowWithFilter(
LeafCollector collector, Bits acceptDocs, int max, DisiWrapper filter) throws IOException {
// TODO: Sometimes load the filter into a bitset and use the more optimized execution paths with
// this bitset as `acceptDocs`
DisiWrapper top = essentialQueue.top();
assert top.doc < max;
if (top.doc < filter.doc) {
top.doc = top.approximation.advance(filter.doc);
}
// Only score an inner window, after that we'll check if the min competitive score has increased
// enough for a more favorable partitioning to be used.
int innerWindowMin = top.doc;
int innerWindowMax = (int) Math.min(max, (long) innerWindowMin + INNER_WINDOW_SIZE);
while (top.doc < innerWindowMax) {
assert filter.doc <= top.doc; // invariant
if (filter.doc < top.doc) {
filter.doc = filter.approximation.advance(top.doc);
}
if (filter.doc != top.doc) {
do {
top.doc = top.iterator.advance(filter.doc);
top = essentialQueue.updateTop();
} while (top.doc < filter.doc);
} else {
int doc = top.doc;
boolean match =
(acceptDocs == null || acceptDocs.get(doc))
&& (filter.twoPhaseView == null || filter.twoPhaseView.matches());
double score = 0;
do {
if (match) {
score += top.scorer.score();
}
top.doc = top.iterator.nextDoc();
top = essentialQueue.updateTop();
} while (top.doc == doc);
if (match) {
scoreNonEssentialClauses(collector, doc, score, firstEssentialScorer);
}
}
}
}
private void scoreInnerWindowSingleEssentialClause(
LeafCollector collector, Bits acceptDocs, int upTo) throws IOException {
DisiWrapper top = essentialQueue.top();
@ -284,9 +337,11 @@ final class MaxScoreBulkScorer extends BulkScorer {
int windowMax = DocIdSetIterator.NO_MORE_DOCS;
for (int i = firstWindowLead; i < allScorers.length; ++i) {
final DisiWrapper scorer = allScorers[i];
if (filter == null || scorer.cost >= filter.cost) {
final int upTo = scorer.scorer.advanceShallow(Math.max(scorer.doc, windowMin));
windowMax = (int) Math.min(windowMax, upTo + 1L); // upTo is inclusive
}
}
if (allScorers.length - firstWindowLead > 1) {
// The more clauses we consider to compute outer windows, the higher chances that one of these

View File

@ -792,6 +792,51 @@ public class TestBooleanRewrites extends LuceneTestCase {
assertEquals(expectedRewritten, searcher.rewrite(query));
}
public void testFlattenDisjunctionInMustClause() throws IOException {
IndexSearcher searcher = newSearcher(new MultiReader());
Query inner =
new BooleanQuery.Builder()
.add(new TermQuery(new Term("foo", "bar")), Occur.SHOULD)
.add(new TermQuery(new Term("foo", "quux")), Occur.SHOULD)
.build();
Query query =
new BooleanQuery.Builder()
.add(inner, Occur.MUST)
.add(new TermQuery(new Term("foo", "baz")), Occur.FILTER)
.build();
Query expectedRewritten =
new BooleanQuery.Builder()
.add(new TermQuery(new Term("foo", "bar")), Occur.SHOULD)
.add(new TermQuery(new Term("foo", "quux")), Occur.SHOULD)
.add(new TermQuery(new Term("foo", "baz")), Occur.FILTER)
.setMinimumNumberShouldMatch(1)
.build();
assertEquals(expectedRewritten, searcher.rewrite(query));
inner =
new BooleanQuery.Builder()
.add(new TermQuery(new Term("foo", "bar")), Occur.SHOULD)
.add(new TermQuery(new Term("foo", "quux")), Occur.SHOULD)
.add(new TermQuery(new Term("foo", "foo")), Occur.SHOULD)
.setMinimumNumberShouldMatch(2)
.build();
query =
new BooleanQuery.Builder()
.add(inner, Occur.MUST)
.add(new TermQuery(new Term("foo", "baz")), Occur.FILTER)
.build();
expectedRewritten =
new BooleanQuery.Builder()
.add(new TermQuery(new Term("foo", "bar")), Occur.SHOULD)
.add(new TermQuery(new Term("foo", "quux")), Occur.SHOULD)
.add(new TermQuery(new Term("foo", "foo")), Occur.SHOULD)
.add(new TermQuery(new Term("foo", "baz")), Occur.FILTER)
.setMinimumNumberShouldMatch(2)
.build();
assertEquals(expectedRewritten, searcher.rewrite(query));
}
public void testDiscardShouldClauses() throws IOException {
IndexSearcher searcher = newSearcher(new MultiReader());

View File

@ -85,7 +85,8 @@ public class TestMaxScoreBulkScorer extends LuceneTestCase {
.scorer(context);
BulkScorer scorer =
new MaxScoreBulkScorer(context.reader().maxDoc(), Arrays.asList(scorer1, scorer2));
new MaxScoreBulkScorer(
context.reader().maxDoc(), Arrays.asList(scorer1, scorer2), null);
scorer.score(
new LeafCollector() {
@ -134,6 +135,141 @@ public class TestMaxScoreBulkScorer extends LuceneTestCase {
}
}
public void testFilteredDisjunction() throws Exception {
try (Directory dir = newDirectory()) {
writeDocuments(dir);
try (IndexReader reader = DirectoryReader.open(dir)) {
IndexSearcher searcher = newSearcher(reader);
Query clause1 =
new BoostQuery(new ConstantScoreQuery(new TermQuery(new Term("foo", "A"))), 2);
Query clause2 = new ConstantScoreQuery(new TermQuery(new Term("foo", "C")));
Query filter = new TermQuery(new Term("foo", "B"));
LeafReaderContext context = searcher.getIndexReader().leaves().get(0);
Scorer scorer1 =
searcher
.createWeight(searcher.rewrite(clause1), ScoreMode.TOP_SCORES, 1f)
.scorer(context);
Scorer scorer2 =
searcher
.createWeight(searcher.rewrite(clause2), ScoreMode.TOP_SCORES, 1f)
.scorer(context);
Scorer filterScorer =
searcher
.createWeight(searcher.rewrite(filter), ScoreMode.TOP_SCORES, 1f)
.scorer(context);
BulkScorer scorer =
new MaxScoreBulkScorer(
context.reader().maxDoc(), Arrays.asList(scorer1, scorer2), filterScorer);
scorer.score(
new LeafCollector() {
private int i;
private Scorable scorer;
@Override
public void setScorer(Scorable scorer) throws IOException {
this.scorer = scorer;
}
@Override
public void collect(int doc) throws IOException {
switch (i++) {
case 0:
assertEquals(0, doc);
assertEquals(2, scorer.score(), 0);
break;
case 1:
assertEquals(12288, doc);
assertEquals(2 + 1, scorer.score(), 0);
break;
case 2:
assertEquals(20480, doc);
assertEquals(1, scorer.score(), 0);
break;
default:
fail();
break;
}
}
},
null,
0,
DocIdSetIterator.NO_MORE_DOCS);
}
}
}
public void testFilteredDisjunctionWithSkipping() throws Exception {
try (Directory dir = newDirectory()) {
writeDocuments(dir);
try (IndexReader reader = DirectoryReader.open(dir)) {
IndexSearcher searcher = newSearcher(reader);
Query clause1 =
new BoostQuery(new ConstantScoreQuery(new TermQuery(new Term("foo", "A"))), 2);
Query clause2 = new ConstantScoreQuery(new TermQuery(new Term("foo", "C")));
Query filter = new TermQuery(new Term("foo", "B"));
LeafReaderContext context = searcher.getIndexReader().leaves().get(0);
Scorer scorer1 =
searcher
.createWeight(searcher.rewrite(clause1), ScoreMode.TOP_SCORES, 1f)
.scorer(context);
Scorer scorer2 =
searcher
.createWeight(searcher.rewrite(clause2), ScoreMode.TOP_SCORES, 1f)
.scorer(context);
Scorer filterScorer =
searcher
.createWeight(searcher.rewrite(filter), ScoreMode.TOP_SCORES, 1f)
.scorer(context);
BulkScorer scorer =
new MaxScoreBulkScorer(
context.reader().maxDoc(), Arrays.asList(scorer1, scorer2), filterScorer);
scorer.score(
new LeafCollector() {
private int i;
private Scorable scorer;
@Override
public void setScorer(Scorable scorer) throws IOException {
this.scorer = scorer;
}
@Override
public void collect(int doc) throws IOException {
switch (i++) {
case 0:
assertEquals(0, doc);
assertEquals(2, scorer.score(), 0);
scorer.setMinCompetitiveScore(Math.nextUp(2));
break;
case 1:
assertEquals(12288, doc);
assertEquals(2 + 1, scorer.score(), 0);
scorer.setMinCompetitiveScore(Math.nextUp(2 + 1));
break;
default:
System.out.println(i);
fail();
break;
}
}
},
null,
0,
DocIdSetIterator.NO_MORE_DOCS);
}
}
}
public void testBasicsWithTwoDisjunctionClausesAndSkipping() throws Exception {
try (Directory dir = newDirectory()) {
writeDocuments(dir);
@ -155,7 +291,8 @@ public class TestMaxScoreBulkScorer extends LuceneTestCase {
.scorer(context);
BulkScorer scorer =
new MaxScoreBulkScorer(context.reader().maxDoc(), Arrays.asList(scorer1, scorer2));
new MaxScoreBulkScorer(
context.reader().maxDoc(), Arrays.asList(scorer1, scorer2), null);
scorer.score(
new LeafCollector() {
@ -227,7 +364,7 @@ public class TestMaxScoreBulkScorer extends LuceneTestCase {
BulkScorer scorer =
new MaxScoreBulkScorer(
context.reader().maxDoc(), Arrays.asList(scorer1, scorer2, scorer3));
context.reader().maxDoc(), Arrays.asList(scorer1, scorer2, scorer3), null);
scorer.score(
new LeafCollector() {
@ -304,7 +441,7 @@ public class TestMaxScoreBulkScorer extends LuceneTestCase {
BulkScorer scorer =
new MaxScoreBulkScorer(
context.reader().maxDoc(), Arrays.asList(scorer1, scorer2, scorer3));
context.reader().maxDoc(), Arrays.asList(scorer1, scorer2, scorer3), null);
scorer.score(
new LeafCollector() {
@ -505,7 +642,8 @@ public class TestMaxScoreBulkScorer extends LuceneTestCase {
fox.cost = 900;
fox.maxScore = 1.1f;
MaxScoreBulkScorer scorer = new MaxScoreBulkScorer(10_000, Arrays.asList(the, quick, fox));
MaxScoreBulkScorer scorer =
new MaxScoreBulkScorer(10_000, Arrays.asList(the, quick, fox), null);
the.docID = 4;
the.maxScoreUpTo = 130;
quick.docID = 4;