LUCENE:8770: BlockMaxConjunctionScorer now leverages two-phase iterators in order to avoid executing the second phase when scorers don't intersect

This commit is contained in:
jimczi 2019-05-21 11:35:44 +02:00
parent 0cb92993db
commit 4640a527a4
3 changed files with 80 additions and 69 deletions

View File

@ -49,6 +49,9 @@ Improvements
* LUCENE-7840: Non-scoring BooleanQuery now removes SHOULD clauses before building the scorer supplier
as opposed to eliminating them during scoring construction. (Atri Sharma via Jim Ferenczi)
* LUCENE-8770: BlockMaxConjunctionScorer now leverages two-phase iterators in order to avoid
executing the second phase when scorers don't intersect. (Adrien Grand, Jim Ferenczi)
======================= Lucene 8.1.1 =======================
(No Changes)

View File

@ -21,54 +21,80 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
/**
* Scorer for conjunctions that checks the maximum scores of each clause in
* order to potentially skip over blocks that can'h have competitive matches.
* order to potentially skip over blocks that can't have competitive matches.
*/
final class BlockMaxConjunctionScorer extends Scorer {
final Scorer[] scorers;
final DocIdSetIterator[] approximations;
final TwoPhaseIterator[] twoPhases;
final MaxScoreSumPropagator maxScorePropagator;
float minScore;
final double[] minScores; // stores the min value of the sum of scores between 0..i for a hit to be competitive
double score;
/** Create a new {@link BlockMaxConjunctionScorer} from scoring clauses. */
BlockMaxConjunctionScorer(Weight weight, Collection<Scorer> scorersList) throws IOException {
super(weight);
this.scorers = scorersList.toArray(new Scorer[scorersList.size()]);
for (Scorer scorer : scorers) {
// Sort scorer by cost
Arrays.sort(this.scorers, Comparator.comparingLong(s -> s.iterator().cost()));
this.maxScorePropagator = new MaxScoreSumPropagator(Arrays.asList(scorers));
this.approximations = new DocIdSetIterator[scorers.length];
List<TwoPhaseIterator> twoPhaseList = new ArrayList<>();
for (int i = 0; i < scorers.length; i++) {
Scorer scorer = scorers[i];
TwoPhaseIterator twoPhase = scorer.twoPhaseIterator();
if (twoPhase != null) {
twoPhaseList.add(twoPhase);
approximations[i] = twoPhase.approximation();
} else {
approximations[i] = scorer.iterator();
}
scorer.advanceShallow(0);
}
this.maxScorePropagator = new MaxScoreSumPropagator(scorersList);
this.twoPhases = twoPhaseList.toArray(new TwoPhaseIterator[twoPhaseList.size()]);
Arrays.sort(this.twoPhases, Comparator.comparingDouble(TwoPhaseIterator::matchCost));
}
// Put scorers with the higher max scores first
// We tie-break on cost
Comparator<Scorer> comparator = (s1, s2) -> {
int cmp;
try {
cmp = Float.compare(s2.getMaxScore(DocIdSetIterator.NO_MORE_DOCS), s1.getMaxScore(DocIdSetIterator.NO_MORE_DOCS));
} catch (IOException e) {
throw new RuntimeException(e);
@Override
public TwoPhaseIterator twoPhaseIterator() {
if (twoPhases.length == 0) {
return null;
}
if (cmp == 0) {
cmp = Long.compare(s1.iterator().cost(), s2.iterator().cost());
float matchCost = (float) Arrays.stream(twoPhases)
.mapToDouble(TwoPhaseIterator::matchCost)
.sum();
final DocIdSetIterator approx = approximation();
return new TwoPhaseIterator(approx) {
@Override
public boolean matches() throws IOException {
for (TwoPhaseIterator twoPhase : twoPhases) {
assert twoPhase.approximation().docID() == docID();
if (twoPhase.matches() == false) {
return false;
}
}
return true;
}
@Override
public float matchCost() {
return matchCost;
}
return cmp;
};
Arrays.sort(this.scorers, comparator);
minScores = new double[this.scorers.length];
}
@Override
public DocIdSetIterator iterator() {
// TODO: support two-phase
final Scorer leadScorer = this.scorers[0]; // higher max score
final DocIdSetIterator[] iterators = Arrays.stream(this.scorers)
.map(Scorer::iterator)
.toArray(DocIdSetIterator[]::new);
final DocIdSetIterator lead = iterators[0];
return twoPhases.length == 0 ? approximation() :
TwoPhaseIterator.asDocIdSetIterator(twoPhaseIterator());
}
private DocIdSetIterator approximation() {
final DocIdSetIterator lead = approximations[0];
return new DocIdSetIterator() {
@ -88,21 +114,6 @@ final class BlockMaxConjunctionScorer extends Scorer {
private void moveToNextBlock(int target) throws IOException {
upTo = advanceShallow(target);
maxScore = getMaxScore(upTo);
// Also compute the minimum required scores for a hit to be competitive
// A double that is less than 'score' might still be converted to 'score'
// when casted to a float, so we go to the previous float to avoid this issue
minScores[minScores.length - 1] = minScore > 0 ? Math.nextDown(minScore) : 0;
for (int i = scorers.length - 1; i > 0; --i) {
double minScore = minScores[i];
float clauseMaxScore = scorers[i].getMaxScore(upTo);
if (minScore > clauseMaxScore) {
minScores[i - 1] = minScore - clauseMaxScore;
assert minScores[i - 1] + clauseMaxScore <= minScore;
} else {
minScores[i - 1] = 0;
}
}
}
private int advanceTarget(int target) throws IOException {
@ -159,18 +170,9 @@ final class BlockMaxConjunctionScorer extends Scorer {
assert doc <= upTo;
if (minScore > 0) {
score = leadScorer.score();
if (score < minScores[0]) {
// computing a score is usually less costly than advancing other clauses
doc = lead.advance(advanceTarget(doc + 1));
continue;
}
}
// then find agreement with other iterators
for (int i = 1; i < iterators.length; ++i) {
final DocIdSetIterator other = iterators[i];
for (int i = 1; i < approximations.length; ++i) {
final DocIdSetIterator other = approximations[i];
// other.doc may already be equal to doc if we "continued advanceHead"
// on the previous iteration and the advance on the lead scorer exactly matched.
if (other.docID() < doc) {
@ -184,23 +186,6 @@ final class BlockMaxConjunctionScorer extends Scorer {
}
assert other.docID() == doc;
if (minScore > 0) {
score += scorers[i].score();
if (score < minScores[i]) {
// computing a score is usually less costly than advancing the next clause
doc = lead.advance(advanceTarget(doc + 1));
continue advanceHead;
}
}
}
if (minScore > 0 == false) {
// the score hasn't been computed on the fly, do it now
score = 0;
for (Scorer scorer : scorers) {
score += scorer.score();
}
}
// success - all iterators are on the same doc and the score is competitive
@ -217,6 +202,10 @@ final class BlockMaxConjunctionScorer extends Scorer {
@Override
public float score() throws IOException {
double score = 0;
for (Scorer scorer : scorers) {
score += scorer.score();
}
return (float) score;
}
@ -257,5 +246,4 @@ final class BlockMaxConjunctionScorer extends Scorer {
}
return children;
}
}

View File

@ -40,6 +40,14 @@ public class TestBlockMaxConjunction extends LuceneTestCase {
return query;
}
private Query maybeWrapTwoPhase(Query query) {
if (random().nextBoolean()) {
query = new RandomApproximationQuery(query, random());
query = new AssertingQuery(random(), query);
}
return query;
}
public void testRandom() throws IOException {
Directory dir = newDirectory();
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig());
@ -75,6 +83,18 @@ public class TestBlockMaxConjunction extends LuceneTestCase {
.build();
CheckHits.checkTopScores(random(), filteredQuery, searcher);
builder = new BooleanQuery.Builder();
for (int i = 0; i < numClauses; ++i) {
builder.add(maybeWrapTwoPhase(new TermQuery(new Term("foo", Integer.toString(start + i)))), Occur.MUST);
}
Query twoPhaseQuery = new BooleanQuery.Builder()
.add(query, Occur.MUST)
.add(new TermQuery(new Term("foo", Integer.toString(filterTerm))), Occur.FILTER)
.build();
CheckHits.checkTopScores(random(), twoPhaseQuery, searcher);
}
reader.close();
dir.close();