mirror of https://github.com/apache/lucene.git
LUCENE-8439: Disjunction max queries can skip blocks to select the top documents when the total hit count is not required
This commit is contained in:
parent
38bf976cd4
commit
ba9b18f367
|
@ -134,6 +134,9 @@ Optimizations
|
|||
or phrase queries as sub queries, which know how to leverage this information
|
||||
to run faster. (Adrien Grand)
|
||||
|
||||
* LUCENE-8439: Disjunction max queries can skip blocks to select the top documents
|
||||
when the total hit count is not required. (Jim Ferenczi, Adrien Grand)
|
||||
|
||||
======================= Lucene 7.5.0 =======================
|
||||
|
||||
API Changes:
|
||||
|
|
|
@ -0,0 +1,94 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.search;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* {@link DocIdSetIterator} that skips non-competitive docs by checking
|
||||
* the max score of the provided {@link Scorer} for the current block.
|
||||
* Call {@link #setMinCompetitiveScore(float)} in order to give this iterator the ability
|
||||
* to skip low-scoring documents.
|
||||
* @lucene.internal
|
||||
*/
|
||||
public class BlockMaxDISI extends DocIdSetIterator {
|
||||
protected final Scorer scorer;
|
||||
private final DocIdSetIterator in;
|
||||
private float minScore;
|
||||
private float maxScore;
|
||||
private int upTo = -1;
|
||||
|
||||
public BlockMaxDISI(DocIdSetIterator iterator, Scorer scorer) {
|
||||
this.in = iterator;
|
||||
this.scorer = scorer;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return in.docID();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return advance(docID()+1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
int doc = advanceImpacts(target);
|
||||
return in.advance(doc);
|
||||
}
|
||||
|
||||
@Override
|
||||
public long cost() {
|
||||
return in.cost();
|
||||
}
|
||||
|
||||
public void setMinCompetitiveScore(float minScore) {
|
||||
this.minScore = minScore;
|
||||
}
|
||||
|
||||
private void moveToNextBlock(int target) throws IOException {
|
||||
upTo = scorer.advanceShallow(target);
|
||||
maxScore = scorer.getMaxScore(upTo);
|
||||
}
|
||||
|
||||
private int advanceImpacts(int target) throws IOException {
|
||||
if (minScore == -1 || target == NO_MORE_DOCS) {
|
||||
return target;
|
||||
}
|
||||
|
||||
if (target > upTo) {
|
||||
moveToNextBlock(target);
|
||||
}
|
||||
|
||||
while (true) {
|
||||
if (maxScore >= minScore) {
|
||||
return target;
|
||||
}
|
||||
|
||||
if (upTo == NO_MORE_DOCS) {
|
||||
return NO_MORE_DOCS;
|
||||
}
|
||||
|
||||
target = upTo + 1;
|
||||
|
||||
moveToNextBlock(target);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -187,7 +187,7 @@ final class Boolean2ScorerSupplier extends ScorerSupplier {
|
|||
} else if (scoreMode == ScoreMode.TOP_SCORES) {
|
||||
return new WANDScorer(weight, optionalScorers);
|
||||
} else {
|
||||
return new DisjunctionSumScorer(weight, optionalScorers, scoreMode.needsScores());
|
||||
return new DisjunctionSumScorer(weight, optionalScorers, scoreMode);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -309,7 +309,7 @@ final class BooleanWeight extends Weight {
|
|||
} else {
|
||||
Scorer prohibitedScorer = prohibited.size() == 1
|
||||
? prohibited.get(0)
|
||||
: new DisjunctionSumScorer(this, prohibited, false);
|
||||
: new DisjunctionSumScorer(this, prohibited, ScoreMode.COMPLETE_NO_SCORES);
|
||||
if (prohibitedScorer.twoPhaseIterator() != null) {
|
||||
// ReqExclBulkScorer can't deal efficiently with two-phased prohibited clauses
|
||||
return null;
|
||||
|
|
|
@ -148,7 +148,7 @@ public final class DisjunctionMaxQuery extends Query implements Iterable<Query>
|
|||
// only one sub-scorer in this segment
|
||||
return scorers.get(0);
|
||||
} else {
|
||||
return new DisjunctionMaxScorer(this, tieBreakerMultiplier, scorers, scoreMode.needsScores());
|
||||
return new DisjunctionMaxScorer(this, tieBreakerMultiplier, scorers, scoreMode);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -21,6 +21,8 @@ import java.util.List;
|
|||
|
||||
import org.apache.lucene.util.MathUtil;
|
||||
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
|
||||
/**
|
||||
* The Scorer for DisjunctionMaxQuery. The union of all documents generated by the the subquery scorers
|
||||
* is generated in document number order. The score for each document is the maximum of the scores computed
|
||||
|
@ -28,9 +30,9 @@ import org.apache.lucene.util.MathUtil;
|
|||
* for the other subqueries that generate the document.
|
||||
*/
|
||||
final class DisjunctionMaxScorer extends DisjunctionScorer {
|
||||
private final List<Scorer> subScorers;
|
||||
/* Multiplier applied to non-maximum-scoring subqueries for a document as they are summed into the result. */
|
||||
private final float tieBreakerMultiplier;
|
||||
private final float maxScore;
|
||||
|
||||
/**
|
||||
* Creates a new instance of DisjunctionMaxScorer
|
||||
|
@ -43,40 +45,13 @@ final class DisjunctionMaxScorer extends DisjunctionScorer {
|
|||
* @param subScorers
|
||||
* The sub scorers this Scorer should iterate on
|
||||
*/
|
||||
DisjunctionMaxScorer(Weight weight, float tieBreakerMultiplier, List<Scorer> subScorers, boolean needsScores) throws IOException {
|
||||
super(weight, subScorers, needsScores);
|
||||
DisjunctionMaxScorer(Weight weight, float tieBreakerMultiplier, List<Scorer> subScorers, ScoreMode scoreMode) throws IOException {
|
||||
super(weight, subScorers, scoreMode);
|
||||
this.subScorers = subScorers;
|
||||
this.tieBreakerMultiplier = tieBreakerMultiplier;
|
||||
if (tieBreakerMultiplier < 0 || tieBreakerMultiplier > 1) {
|
||||
throw new IllegalArgumentException("tieBreakerMultiplier must be in [0, 1]");
|
||||
}
|
||||
|
||||
if (needsScores == false) {
|
||||
this.maxScore = Float.MAX_VALUE;
|
||||
} else {
|
||||
float scoreMax = 0;
|
||||
double otherScoreSum = 0;
|
||||
for (Scorer scorer : subScorers) {
|
||||
scorer.advanceShallow(0);
|
||||
float subScore = scorer.getMaxScore(DocIdSetIterator.NO_MORE_DOCS);
|
||||
if (subScore >= scoreMax) {
|
||||
otherScoreSum += scoreMax;
|
||||
scoreMax = subScore;
|
||||
} else {
|
||||
otherScoreSum += subScore;
|
||||
}
|
||||
}
|
||||
|
||||
if (tieBreakerMultiplier == 0) {
|
||||
this.maxScore = scoreMax;
|
||||
} else {
|
||||
// The error of sums depends on the order in which values are summed up. In
|
||||
// order to avoid this issue, we compute an upper bound of the value that
|
||||
// the sum may take. If the max relative error is b, then it means that two
|
||||
// sums are always within 2*b of each other.
|
||||
otherScoreSum *= (1 + 2 * MathUtil.sumRelativeErrorBound(subScorers.size() - 1));
|
||||
this.maxScore = (float) (scoreMax + otherScoreSum * tieBreakerMultiplier);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -95,8 +70,49 @@ final class DisjunctionMaxScorer extends DisjunctionScorer {
|
|||
return (float) (scoreMax + otherScoreSum * tieBreakerMultiplier);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advanceShallow(int target) throws IOException {
|
||||
int upTo = NO_MORE_DOCS;
|
||||
for (Scorer scorer : subScorers) {
|
||||
if (scorer.docID() <= target) {
|
||||
upTo = Math.min(scorer.advanceShallow(target), upTo);
|
||||
} else if (scorer.docID() < NO_MORE_DOCS) {
|
||||
upTo = Math.min(scorer.docID()-1, upTo);
|
||||
}
|
||||
}
|
||||
return upTo;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float getMaxScore(int upTo) throws IOException {
|
||||
return maxScore;
|
||||
float scoreMax = 0;
|
||||
double otherScoreSum = 0;
|
||||
for (Scorer scorer : subScorers) {
|
||||
if (scorer.docID() <= upTo) {
|
||||
float subScore = scorer.getMaxScore(upTo);
|
||||
if (subScore >= scoreMax) {
|
||||
otherScoreSum += scoreMax;
|
||||
scoreMax = subScore;
|
||||
} else {
|
||||
otherScoreSum += subScore;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (tieBreakerMultiplier == 0) {
|
||||
return scoreMax;
|
||||
} else {
|
||||
// The error of sums depends on the order in which values are summed up. In
|
||||
// order to avoid this issue, we compute an upper bound of the value that
|
||||
// the sum may take. If the max relative error is b, then it means that two
|
||||
// sums are always within 2*b of each other.
|
||||
otherScoreSum *= (1 + 2 * MathUtil.sumRelativeErrorBound(subScorers.size() - 1));
|
||||
return (float) (scoreMax + otherScoreSum * tieBreakerMultiplier);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setMinCompetitiveScore(float minScore) {
|
||||
getBlockMaxApprox().setMinCompetitiveScore(minScore);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -32,10 +32,11 @@ abstract class DisjunctionScorer extends Scorer {
|
|||
private final boolean needsScores;
|
||||
|
||||
private final DisiPriorityQueue subScorers;
|
||||
private final DisjunctionDISIApproximation approximation;
|
||||
private final DocIdSetIterator approximation;
|
||||
private final BlockMaxDISI blockMaxApprox;
|
||||
private final TwoPhase twoPhase;
|
||||
|
||||
protected DisjunctionScorer(Weight weight, List<Scorer> subScorers, boolean needsScores) {
|
||||
protected DisjunctionScorer(Weight weight, List<Scorer> subScorers, ScoreMode scoreMode) throws IOException {
|
||||
super(weight);
|
||||
if (subScorers.size() <= 1) {
|
||||
throw new IllegalArgumentException("There must be at least 2 subScorers");
|
||||
|
@ -45,8 +46,17 @@ abstract class DisjunctionScorer extends Scorer {
|
|||
final DisiWrapper w = new DisiWrapper(scorer);
|
||||
this.subScorers.add(w);
|
||||
}
|
||||
this.needsScores = needsScores;
|
||||
this.approximation = new DisjunctionDISIApproximation(this.subScorers);
|
||||
this.needsScores = scoreMode != ScoreMode.COMPLETE_NO_SCORES;
|
||||
if (scoreMode == ScoreMode.TOP_SCORES) {
|
||||
for (Scorer scorer : subScorers) {
|
||||
scorer.advanceShallow(0);
|
||||
}
|
||||
this.blockMaxApprox = new BlockMaxDISI(new DisjunctionDISIApproximation(this.subScorers), this);
|
||||
this.approximation = blockMaxApprox;
|
||||
} else {
|
||||
this.approximation = new DisjunctionDISIApproximation(this.subScorers);
|
||||
this.blockMaxApprox = null;
|
||||
}
|
||||
|
||||
boolean hasApproximation = false;
|
||||
float sumMatchCost = 0;
|
||||
|
@ -167,6 +177,10 @@ abstract class DisjunctionScorer extends Scorer {
|
|||
return subScorers.top().doc;
|
||||
}
|
||||
|
||||
BlockMaxDISI getBlockMaxApprox() {
|
||||
return blockMaxApprox;
|
||||
}
|
||||
|
||||
DisiWrapper getSubMatches() throws IOException {
|
||||
if (twoPhase == null) {
|
||||
return subScorers.topList();
|
||||
|
|
|
@ -28,8 +28,8 @@ final class DisjunctionSumScorer extends DisjunctionScorer {
|
|||
* @param weight The weight to be used.
|
||||
* @param subScorers Array of at least two subscorers.
|
||||
*/
|
||||
DisjunctionSumScorer(Weight weight, List<Scorer> subScorers, boolean needsScores) throws IOException {
|
||||
super(weight, subScorers, needsScores);
|
||||
DisjunctionSumScorer(Weight weight, List<Scorer> subScorers, ScoreMode scoreMode) throws IOException {
|
||||
super(weight, subScorers, scoreMode);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -16,18 +16,22 @@
|
|||
*/
|
||||
package org.apache.lucene.search;
|
||||
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.StringReader;
|
||||
import java.text.DecimalFormat;
|
||||
import java.text.DecimalFormatSymbols;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
|
||||
import org.apache.lucene.analysis.Analyzer;
|
||||
import org.apache.lucene.analysis.MockAnalyzer;
|
||||
import org.apache.lucene.analysis.standard.StandardAnalyzer;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.document.FieldType;
|
||||
import org.apache.lucene.document.StringField;
|
||||
import org.apache.lucene.document.TextField;
|
||||
import org.apache.lucene.index.DirectoryReader;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
|
@ -522,6 +526,77 @@ public class TestDisjunctionMaxQuery extends LuceneTestCase {
|
|||
assertEquals(bq.clauses().get(0), new BooleanClause(sub1, BooleanClause.Occur.SHOULD));
|
||||
assertEquals(bq.clauses().get(1), new BooleanClause(sub2, BooleanClause.Occur.SHOULD));
|
||||
}
|
||||
|
||||
public void testRandomTopDocs() throws Exception {
|
||||
doTestRandomTopDocs(2, 0.05f, 0.05f);
|
||||
doTestRandomTopDocs(2, 1.0f, 0.05f);
|
||||
doTestRandomTopDocs(3, 1.0f, 0.5f, 0.05f);
|
||||
doTestRandomTopDocs(4, 1.0f, 0.5f, 0.05f, 0f);
|
||||
doTestRandomTopDocs(4, 1.0f, 0.5f, 0.05f, 0f);
|
||||
}
|
||||
|
||||
private void doTestRandomTopDocs(int numFields, double... freqs) throws IOException {
|
||||
assert numFields == freqs.length;
|
||||
Directory dir = newDirectory();
|
||||
IndexWriterConfig config = new IndexWriterConfig(new StandardAnalyzer());
|
||||
IndexWriter w = new IndexWriter(dir, config);
|
||||
|
||||
int numDocs = atLeast(1000); // make sure some terms have skip data
|
||||
for (int i = 0; i < numDocs; i++) {
|
||||
Document doc = new Document();
|
||||
for (int j = 0; j < numFields; j++) {
|
||||
StringBuilder builder = new StringBuilder();
|
||||
int numAs = random().nextDouble() < freqs[j] ? 0 : 1 + random().nextInt(5);
|
||||
for (int k = 0; k < numAs; k++) {
|
||||
if (builder.length() > 0) {
|
||||
builder.append(' ');
|
||||
}
|
||||
builder.append('a');
|
||||
}
|
||||
if (random().nextBoolean()) {
|
||||
doc.add(new StringField("field", "c", Field.Store.NO));
|
||||
}
|
||||
int numOthers = random().nextBoolean() ? 0 : 1 + random().nextInt(5);
|
||||
for (int k = 0; k < numOthers; k++) {
|
||||
if (builder.length() > 0) {
|
||||
builder.append(' ');
|
||||
}
|
||||
builder.append(Integer.toString(random().nextInt()));
|
||||
}
|
||||
doc.add(new TextField(Integer.toString(j), new StringReader(builder.toString())));
|
||||
}
|
||||
w.addDocument(doc);
|
||||
}
|
||||
IndexReader reader = DirectoryReader.open(w);
|
||||
w.close();
|
||||
IndexSearcher searcher = newSearcher(reader);
|
||||
for (int i = 0; i < 4; i++) {
|
||||
List<Query> clauses = new ArrayList<>();
|
||||
for (int j = 0; j < numFields; j++) {
|
||||
if (i % 2 == 1) {
|
||||
clauses.add(tq(Integer.toString(j), "a"));
|
||||
} else {
|
||||
float boost = random().nextBoolean() ? 0 : random().nextFloat();
|
||||
if (boost > 0) {
|
||||
clauses.add(tq(Integer.toString(j), "a", boost));
|
||||
} else {
|
||||
clauses.add(tq(Integer.toString(j), "a"));
|
||||
}
|
||||
}
|
||||
}
|
||||
float tieBreaker = random().nextFloat();
|
||||
Query query = new DisjunctionMaxQuery(clauses, tieBreaker);
|
||||
CheckHits.checkTopScores(random(), query, searcher);
|
||||
|
||||
query = new BooleanQuery.Builder()
|
||||
.add(new DisjunctionMaxQuery(clauses, tieBreaker), BooleanClause.Occur.MUST)
|
||||
.add(tq("field", "c"), BooleanClause.Occur.FILTER)
|
||||
.build();
|
||||
CheckHits.checkTopScores(random(), query, searcher);
|
||||
}
|
||||
reader.close();
|
||||
dir.close();
|
||||
}
|
||||
|
||||
/** macro */
|
||||
protected Query tq(String f, String t) {
|
||||
|
|
Loading…
Reference in New Issue