Run top-level conjunctions of term queries with a specialized BulkScorer. (#12382)

This implements a specialized `BlockMaxConjunctionBulkScorer`, which is really
the same as `BlockMaxConjunctionScorer`, but as a `BulkScorer` instead of a
`Scorer`. Also it doesn't support two-phase iterators in order to focus on the
common case when queries, such as term queries, do not have two-phase
iterators. If a clause has a two-phase iterator, it will keep running as a
`BlockMaxConjunctionScorer` wrapped in a `DefaultBulkScorer`.
This commit is contained in:
Adrien Grand 2023-09-25 13:36:44 +02:00 committed by GitHub
parent d48913a957
commit f2bd0bbcdd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 276 additions and 22 deletions

View File

@ -96,6 +96,9 @@ Optimizations
* GITHUB#12552: Make FSTPostingsFormat load FSTs off-heap. (Tony X)
* GITHUB#12382: Faster top-level conjunctions on term queries when sorting by
descending score. (Adrien Grand)
Bug Fixes
---------------------

View File

@ -0,0 +1,176 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import org.apache.lucene.search.Weight.DefaultBulkScorer;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.MathUtil;
/**
 * BulkScorer implementation of {@link BlockMaxConjunctionScorer} that focuses on top-level
 * conjunctions over clauses that do not have two-phase iterators. Use a {@link DefaultBulkScorer}
 * around a {@link BlockMaxConjunctionScorer} if you need two-phase support. Another difference with
 * {@link BlockMaxConjunctionScorer} is that this scorer computes scores on the fly in order to be
 * able to skip evaluating more clauses if the total score would be under the minimum competitive
 * score anyway. This generally works well because computing a score is cheaper than decoding a
 * block of postings.
 */
final class BlockMaxConjunctionBulkScorer extends BulkScorer {

  /** Clause scorers, sorted by ascending iterator cost. */
  private final Scorer[] scorers;

  /** Iterators of {@link #scorers}, in the same order. */
  private final DocIdSetIterator[] iterators;

  /** Iterator of the least costly clause; it drives the iteration. */
  private final DocIdSetIterator lead;

  /** The {@link Scorable} handed to the collector; carries the current score and threshold. */
  private final DocAndScore scorable = new DocAndScore();

  /**
   * For the current window, {@code sumOfOtherClauses[i]} is the sum of the maximum scores of
   * clauses {@code i} (inclusive) to the last clause, i.e. the most that the clauses that have not
   * been scored yet can still contribute once clauses {@code 0..i-1} have been scored.
   */
  private final double[] sumOfOtherClauses;

  /**
   * Creates a bulk scorer over the conjunction of the given scorers.
   *
   * @param scorers the clauses of the conjunction, at least two
   * @throws IllegalArgumentException if fewer than two scorers are provided
   */
  BlockMaxConjunctionBulkScorer(List<Scorer> scorers) throws IOException {
    if (scorers.size() <= 1) {
      throw new IllegalArgumentException("Expected 2 or more scorers, got " + scorers.size());
    }
    this.scorers = scorers.toArray(Scorer[]::new);
    // Cheapest clause first, so that it leads the iteration.
    Arrays.sort(this.scorers, Comparator.comparingLong(scorer -> scorer.iterator().cost()));
    this.iterators =
        Arrays.stream(this.scorers).map(Scorer::iterator).toArray(DocIdSetIterator[]::new);
    lead = iterators[0];
    this.sumOfOtherClauses = new double[this.scorers.length];
  }

  /**
   * Scores documents in {@code [min, max)} window by window. For each window, the per-clause
   * maximum scores are gathered and turned into suffix sums before the window is scored.
   *
   * @return the larger of the lead iterator's current doc ID and the end of the last window
   */
  @Override
  public int score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException {
    collector.setScorer(scorable);

    int windowMin = Math.max(lead.docID(), min);
    while (windowMin < max) {
      // Use impacts of the least costly scorer to compute windows
      // NOTE: windowMax is inclusive
      int windowMax = Math.min(scorers[0].advanceShallow(windowMin), max - 1);
      for (int i = 1; i < scorers.length; ++i) {
        scorers[i].advanceShallow(windowMin);
      }

      double maxWindowScore = 0;
      for (int i = 0; i < scorers.length; ++i) {
        double maxClauseScore = scorers[i].getMaxScore(windowMax);
        sumOfOtherClauses[i] = maxClauseScore;
        maxWindowScore += maxClauseScore;
      }
      // Turn the per-clause maximum scores into suffix sums.
      for (int i = sumOfOtherClauses.length - 2; i >= 0; --i) {
        sumOfOtherClauses[i] += sumOfOtherClauses[i + 1];
      }
      scoreWindow(collector, acceptDocs, windowMin, windowMax + 1, (float) maxWindowScore);
      // scoreWindow may have returned early without moving the lead iterator, so make sure that
      // the next window starts after the one that was just processed.
      windowMin = Math.max(lead.docID(), windowMax + 1);
    }

    return windowMin;
  }

  /**
   * Collects matches in {@code [min, max)}, where {@code max} is exclusive. Returns early, possibly
   * without advancing any iterator, as soon as {@code maxWindowScore} is less than the minimum
   * competitive score.
   */
  private void scoreWindow(
      LeafCollector collector, Bits acceptDocs, int min, int max, float maxWindowScore)
      throws IOException {
    if (maxWindowScore < scorable.minCompetitiveScore) {
      // no hits are competitive
      return;
    }

    if (lead.docID() < min) {
      lead.advance(min);
    }
    advanceHead:
    for (int doc = lead.docID(); doc < max; ) {
      if (acceptDocs != null && acceptDocs.get(doc) == false) {
        doc = lead.nextDoc();
        continue;
      }

      // Compute the score as we find more matching clauses, in order to skip advancing other
      // clauses if the total score has no chance of being competitive. This works well because
      // computing a score is usually cheaper than decoding a full block of postings and
      // frequencies.
      final boolean hasMinCompetitiveScore = scorable.minCompetitiveScore > 0;
      double currentScore;
      if (hasMinCompetitiveScore) {
        currentScore = scorers[0].score();
      } else {
        currentScore = 0;
      }

      for (int i = 1; i < iterators.length; ++i) {
        // First check if we have a chance of having a match.
        // The bound is cast to float before the comparison because the score that is eventually
        // exposed to the collector is a float: a double bound that is slightly less than the
        // threshold may still round to a float that is equal to it, and such a hit must not be
        // skipped.
        if (hasMinCompetitiveScore
            && (float) MathUtil.sumUpperBound(currentScore + sumOfOtherClauses[i], scorers.length)
                < scorable.minCompetitiveScore) {
          doc = lead.nextDoc();
          continue advanceHead;
        }

        // NOTE: these iterators may already be on `doc` if we called `continue advanceHead` on
        // the previous loop iteration.
        if (iterators[i].docID() < doc) {
          int next = iterators[i].advance(doc);
          if (next != doc) {
            // This clause doesn't match `doc`; move the lead to its next candidate.
            doc = lead.advance(next);
            continue advanceHead;
          }
        }
        assert iterators[i].docID() == doc;
        if (hasMinCompetitiveScore) {
          currentScore += scorers[i].score();
        }
      }

      if (hasMinCompetitiveScore == false) {
        // Scores were not accumulated while checking clauses, compute the total now.
        for (Scorer scorer : scorers) {
          currentScore += scorer.score();
        }
      }
      scorable.score = (float) currentScore;
      collector.collect(doc);
      // The collect() call may have updated the minimum competitive score.
      if (maxWindowScore < scorable.minCompetitiveScore) {
        // no more hits are competitive
        return;
      }

      doc = lead.nextDoc();
    }
  }

  /** Returns the cost of the lead (least costly) clause. */
  @Override
  public long cost() {
    return lead.cost();
  }

  /** Mutable {@link Scorable} that exposes the current score and records the threshold. */
  private static class DocAndScore extends Scorable {

    /** Score of the document that is being collected. */
    float score;

    /** Last value passed to {@link #setMinCompetitiveScore(float)}, 0 if never called. */
    float minCompetitiveScore;

    @Override
    public float score() throws IOException {
      return score;
    }

    @Override
    public void setMinCompetitiveScore(float minScore) throws IOException {
      this.minCompetitiveScore = minScore;
    }
  }
}

View File

@ -19,9 +19,11 @@ package org.apache.lucene.search;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.BooleanClause.Occur;
import org.apache.lucene.search.similarities.Similarity;
@ -157,6 +159,7 @@ final class BooleanWeight extends Weight {
}
static BulkScorer disableScoring(final BulkScorer scorer) {
Objects.requireNonNull(scorer);
return new BulkScorer() {
@Override
@ -250,31 +253,103 @@ final class BooleanWeight extends Weight {
this, optional, Math.max(1, query.getMinimumNumberShouldMatch()), scoreMode.needsScores());
}
// Return a BulkScorer for the required clauses only,
// or null if it is not applicable
// Return a BulkScorer for the required clauses only
private BulkScorer requiredBulkScorer(LeafReaderContext context) throws IOException {
BulkScorer scorer = null;
// Is there a single required clause by any chance? Then pull its bulk scorer.
List<WeightedBooleanClause> requiredClauses = new ArrayList<>();
for (WeightedBooleanClause wc : weightedClauses) {
Weight w = wc.weight;
BooleanClause c = wc.clause;
if (c.isRequired() == false) {
continue;
}
if (scorer != null) {
// we don't have a BulkScorer for conjunctions
return null;
}
scorer = w.bulkScorer(context);
if (scorer == null) {
// no matches
return null;
}
if (c.isScoring() == false && scoreMode.needsScores()) {
scorer = disableScoring(scorer);
if (wc.clause.isRequired()) {
requiredClauses.add(wc);
}
}
return scorer;
if (requiredClauses.isEmpty()) {
// No required clauses at all.
return null;
} else if (requiredClauses.size() == 1) {
WeightedBooleanClause clause = requiredClauses.get(0);
BulkScorer scorer = clause.weight.bulkScorer(context);
if (scorer == null) {
return null;
}
if (clause.clause.isScoring() == false && scoreMode.needsScores()) {
scorer = disableScoring(scorer);
}
return scorer;
}
List<ScorerSupplier> requiredNoScoringSupplier = new ArrayList<>();
List<ScorerSupplier> requiredScoringSupplier = new ArrayList<>();
long leadCost = Long.MAX_VALUE;
for (WeightedBooleanClause wc : requiredClauses) {
Weight w = wc.weight;
BooleanClause c = wc.clause;
ScorerSupplier scorerSupplier = w.scorerSupplier(context);
if (scorerSupplier == null) {
// One clause doesn't have matches, so the entire conjunction doesn't have matches.
return null;
}
leadCost = Math.min(leadCost, scorerSupplier.cost());
if (c.isScoring() && scoreMode.needsScores()) {
requiredScoringSupplier.add(scorerSupplier);
} else {
requiredNoScoringSupplier.add(scorerSupplier);
}
}
List<Scorer> requiredNoScoring = new ArrayList<>();
for (ScorerSupplier ss : requiredNoScoringSupplier) {
requiredNoScoring.add(ss.get(leadCost));
}
List<Scorer> requiredScoring = new ArrayList<>();
for (ScorerSupplier ss : requiredScoringSupplier) {
if (requiredScoringSupplier.size() == 1) {
ss.setTopLevelScoringClause();
}
requiredScoring.add(ss.get(leadCost));
}
if (scoreMode == ScoreMode.TOP_SCORES
&& requiredNoScoringSupplier.isEmpty()
&& requiredScoring.size() > 1
// Only specialize top-level conjunctions for clauses that don't have a two-phase iterator.
&& requiredScoring.stream().map(Scorer::twoPhaseIterator).allMatch(Objects::isNull)) {
return new BlockMaxConjunctionBulkScorer(requiredScoring);
}
if (scoreMode == ScoreMode.TOP_SCORES && requiredScoring.size() > 1) {
requiredScoring =
Collections.singletonList(new BlockMaxConjunctionScorer(this, requiredScoring));
}
Scorer conjunctionScorer;
if (requiredNoScoring.size() + requiredScoring.size() == 1) {
if (requiredScoring.size() == 1) {
conjunctionScorer = requiredScoring.get(0);
} else {
conjunctionScorer = requiredNoScoring.get(0);
if (scoreMode.needsScores()) {
Scorer inner = conjunctionScorer;
conjunctionScorer =
new FilterScorer(inner) {
@Override
public float score() throws IOException {
return 0f;
}
@Override
public float getMaxScore(int upTo) throws IOException {
return 0f;
}
};
}
}
} else {
List<Scorer> required = new ArrayList<>();
required.addAll(requiredScoring);
required.addAll(requiredNoScoring);
conjunctionScorer = new ConjunctionScorer(this, required, requiredScoring);
}
return new DefaultBulkScorer(conjunctionScorer);
}
/**
@ -314,7 +389,7 @@ final class BooleanWeight extends Weight {
return null;
}
} else if (numRequiredClauses == 1
} else if (numRequiredClauses > 0
&& numOptionalClauses == 0
&& query.getMinimumNumberShouldMatch() == 0) {
positiveScorer = requiredBulkScorer(context);