Introduce a BulkScorer for DisjunctionMaxQuery. (#14040)

This introduces a bulk scorer for `DisjunctionMaxQuery` that delegates to the
bulk scorers of the query clauses. It is used when the tie-break multiplier is
0, and improves the performance of top-level `DisjunctionMaxQuery`, especially
when its clauses have optimized bulk scorers themselves (e.g. disjunctions).
This commit is contained in:
Adrien Grand 2024-12-06 11:01:01 +01:00 committed by GitHub
parent 8103f2a44a
commit c88f9334e5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 190 additions and 0 deletions

View File

@ -130,6 +130,9 @@ Optimizations
* GITHUB#14032: Speed up PostingsEnum when positions are requested.
(Adrien Grand)
* GITHUB#14040: Specialized top-level DisjunctionMaxQuery evaluation when the
tie break multiplier is 0. (Adrien Grand)
Bug Fixes
---------------------
* GITHUB#13832: Fixed an issue where the DefaultPassageFormatter.format method did not format passages as intended

View File

@ -0,0 +1,127 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.PriorityQueue;
/**
 * Bulk scorer for {@link DisjunctionMaxQuery} when the tie-break multiplier is zero.
 *
 * <p>Documents are processed in windows of at most {@link #WINDOW_SIZE} doc IDs. For each window,
 * every clause's {@link BulkScorer} is run over the window and, for each collected doc, the
 * maximum of the clause scores is recorded. The recorded matches are then replayed in increasing
 * doc ID order to the caller's {@link LeafCollector}.
 */
final class DisjunctionMaxBulkScorer extends BulkScorer {

  // Same window size as BooleanScorer
  private static final int WINDOW_SIZE = 4096;

  /**
   * Pairs a clause's bulk scorer with {@code next}, the value returned by the most recent call to
   * its {@code score} method (0 before any call).
   */
  private static class BulkScorerAndNext {
    public final BulkScorer scorer;
    public int next = 0;

    BulkScorerAndNext(BulkScorer scorer) {
      this.scorer = Objects.requireNonNull(scorer);
    }
  }

  // WINDOW_SIZE + 1 to ease iteration on the bit set
  private final FixedBitSet windowMatches = new FixedBitSet(WINDOW_SIZE + 1);
  // Best clause score per window slot; reset to 0 after every window
  private final float[] windowScores = new float[WINDOW_SIZE];
  // Clause scorers ordered by their `next` value, smallest on top
  private final PriorityQueue<BulkScorerAndNext> scorers;
  // Scorable handed to the caller's collector; also receives its min competitive score updates
  private final SimpleScorable topLevelScorable = new SimpleScorable();

  /**
   * Creates a bulk scorer over the given clause bulk scorers.
   *
   * @param scorers bulk scorers of the clauses; each element must be non-null
   * @throws IllegalArgumentException if fewer than two scorers are provided
   * @throws NullPointerException if any element of {@code scorers} is null
   */
  DisjunctionMaxBulkScorer(List<BulkScorer> scorers) {
    if (scorers.size() < 2) {
      throw new IllegalArgumentException(
          "DisjunctionMaxBulkScorer requires at least 2 clauses, got " + scorers.size());
    }
    this.scorers =
        new PriorityQueue<>(scorers.size()) {
          @Override
          protected boolean lessThan(BulkScorerAndNext a, BulkScorerAndNext b) {
            return a.next < b.next;
          }
        };
    for (BulkScorer scorer : scorers) {
      this.scorers.add(new BulkScorerAndNext(scorer));
    }
  }

  /**
   * Scores documents in {@code [min, max)} by delegating to the clause bulk scorers window by
   * window and forwarding each matching doc once, with the maximum clause score, to {@code
   * collector}.
   *
   * @param collector receives {@code setScorer} once per window followed by the window's matches
   * @param acceptDocs passed through unchanged to every clause bulk scorer
   * @param min lower bound of the range to score
   * @param max upper bound of the range to score
   * @return the smallest {@code next} value across clause scorers once it is no longer below
   *     {@code max}
   * @throws IOException if a clause scorer or the collector throws
   */
  @Override
  public int score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException {
    BulkScorerAndNext top = scorers.top();

    while (top.next < max) {
      final int windowMin = Math.max(top.next, min);
      // Computed as a long so that windowMin + WINDOW_SIZE cannot overflow
      final int windowMax = (int) Math.min(max, (long) windowMin + WINDOW_SIZE);

      // First compute matches / scores in the window
      do {
        top.next =
            top.scorer.score(
                new LeafCollector() {

                  private Scorable scorer;

                  @Override
                  public void setScorer(Scorable scorer) throws IOException {
                    this.scorer = scorer;
                    // Propagate any threshold already set on the top-level scorable
                    if (topLevelScorable.minCompetitiveScore != 0f) {
                      scorer.setMinCompetitiveScore(topLevelScorable.minCompetitiveScore);
                    }
                  }

                  @Override
                  public void collect(int doc) throws IOException {
                    final int delta = doc - windowMin;
                    windowMatches.set(delta);
                    // Keep only the best clause score for this doc
                    windowScores[delta] = Math.max(windowScores[delta], scorer.score());
                  }
                },
                acceptDocs,
                windowMin,
                windowMax);
        top = scorers.updateTop();
      } while (top.next < windowMax);

      // Then replay
      collector.setScorer(topLevelScorable);
      for (int windowDoc = windowMatches.nextSetBit(0);
          windowDoc != DocIdSetIterator.NO_MORE_DOCS;
          windowDoc = windowMatches.nextSetBit(windowDoc + 1)) {
        int doc = windowMin + windowDoc;
        topLevelScorable.score = windowScores[windowDoc];
        collector.collect(doc);
      }

      // Finally clean up state
      windowMatches.clear();
      Arrays.fill(windowScores, 0f);
    }

    return top.next;
  }

  /** Returns the sum of the costs reported by all clause bulk scorers. */
  @Override
  public long cost() {
    long cost = 0;
    for (BulkScorerAndNext scorer : scorers) {
      cost += scorer.scorer.cost();
    }
    return cost;
  }
}

View File

@ -158,6 +158,18 @@ public final class DisjunctionMaxQuery extends Query implements Iterable<Query>
return new DisjunctionMaxScorer(tieBreakerMultiplier, scorers, scoreMode);
}
/**
 * Returns a {@link DisjunctionMaxBulkScorer} built from the bulk scorers of every clause supplier
 * when the tie-break multiplier is zero; otherwise defers to the superclass implementation.
 */
@Override
public BulkScorer bulkScorer() throws IOException {
  // Any non-zero tie breaker is handled by the inherited bulk scorer.
  if (tieBreakerMultiplier != 0f) {
    return super.bulkScorer();
  }

  final List<BulkScorer> clauseBulkScorers = new ArrayList<>();
  for (ScorerSupplier supplier : scorerSuppliers) {
    final BulkScorer clauseBulkScorer = supplier.bulkScorer();
    clauseBulkScorers.add(clauseBulkScorer);
  }
  return new DisjunctionMaxBulkScorer(clauseBulkScorers);
}
@Override
public long cost() {
if (cost == -1) {

View File

@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
/**
 * Minimal {@link Scorable} backed by two plain fields: the current score and the minimum
 * competitive score. Both are exposed through trivial accessors.
 */
final class SimpleScorable extends Scorable {

  /** Value returned by {@link #score()}. */
  float score;

  /** Value returned by {@link #minCompetitiveScore()}. */
  float minCompetitiveScore;

  /** Sole constructor. */
  public SimpleScorable() {}

  /** Set the score. */
  public void setScore(float score) {
    this.score = score;
  }

  /** Returns the score last stored in this instance. */
  @Override
  public float score() {
    return this.score;
  }

  /** Records {@code minScore} as the min competitive score. */
  @Override
  public void setMinCompetitiveScore(float minScore) throws IOException {
    this.minCompetitiveScore = minScore;
  }

  /** Get the min competitive score. */
  public float minCompetitiveScore() {
    return this.minCompetitiveScore;
  }
}