LUCENE-8439: Disjunction max queries can skip blocks to select the top documents when the total hit count is not required

This commit is contained in:
Jim Ferenczi 2018-08-08 12:34:42 +02:00
parent 38bf976cd4
commit ba9b18f367
9 changed files with 244 additions and 42 deletions

View File

@ -134,6 +134,9 @@ Optimizations
or phrase queries as sub queries, which know how to leverage this information
to run faster. (Adrien Grand)
* LUCENE-8439: Disjunction max queries can skip blocks to select the top documents
when the total hit count is not required. (Jim Ferenczi, Adrien Grand)
======================= Lucene 7.5.0 =======================
API Changes:

View File

@ -0,0 +1,94 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
/**
 * A {@link DocIdSetIterator} wrapper that consults the per-block maximum score
 * exposed by a {@link Scorer} and jumps over blocks whose maximum score is
 * below the current minimum competitive score.
 * The threshold is supplied through {@link #setMinCompetitiveScore(float)}.
 * @lucene.internal
 */
public class BlockMaxDISI extends DocIdSetIterator {
  /** Scorer that is asked for block boundaries and block max scores. */
  protected final Scorer scorer;
  /** Wrapped iterator that performs the actual positioning. */
  private final DocIdSetIterator in;
  /** Current minimum competitive score; blocks scoring below it are skipped. */
  private float minScore;
  /** Max score of the block loaded by the last {@link #moveToNextBlock(int)} call. */
  private float maxScore;
  /** Last doc ID covered by the currently loaded block, -1 before any block is loaded. */
  private int upTo = -1;

  /**
   * @param iterator the iterator to delegate positioning to
   * @param scorer the scorer queried for block boundaries and max scores
   */
  public BlockMaxDISI(DocIdSetIterator iterator, Scorer scorer) {
    this.in = iterator;
    this.scorer = scorer;
  }

  @Override
  public int docID() {
    return in.docID();
  }

  @Override
  public int nextDoc() throws IOException {
    final int following = docID() + 1;
    return advance(following);
  }

  @Override
  public int advance(int target) throws IOException {
    // First skip non-competitive blocks, then let the wrapped iterator land on a real doc.
    return in.advance(advanceImpacts(target));
  }

  @Override
  public long cost() {
    return in.cost();
  }

  /** Sets the score a block must be able to reach in order not to be skipped. */
  public void setMinCompetitiveScore(float minScore) {
    this.minScore = minScore;
  }

  /** Loads the block containing {@code target} and records its end and max score. */
  private void moveToNextBlock(int target) throws IOException {
    final int blockEnd = scorer.advanceShallow(target);
    upTo = blockEnd;
    maxScore = scorer.getMaxScore(blockEnd);
  }

  /**
   * Returns the first doc ID at or after {@code target} that lies in a block whose
   * max score reaches {@code minScore}, or {@code NO_MORE_DOCS} when none remains.
   */
  private int advanceImpacts(int target) throws IOException {
    // A threshold of exactly -1 turns block skipping off.
    final boolean skippingDisabled = minScore == -1;
    if (skippingDisabled || target == NO_MORE_DOCS) {
      return target;
    }
    if (upTo < target) {
      moveToNextBlock(target);
    }
    int candidate = target;
    // Written as a negated >= so that the comparison behaves the same for NaN scores.
    while (!(maxScore >= minScore)) {
      if (upTo == NO_MORE_DOCS) {
        return NO_MORE_DOCS;
      }
      candidate = upTo + 1;
      moveToNextBlock(candidate);
    }
    return candidate;
  }
}

View File

@ -187,7 +187,7 @@ final class Boolean2ScorerSupplier extends ScorerSupplier {
} else if (scoreMode == ScoreMode.TOP_SCORES) {
return new WANDScorer(weight, optionalScorers);
} else {
return new DisjunctionSumScorer(weight, optionalScorers, scoreMode.needsScores());
return new DisjunctionSumScorer(weight, optionalScorers, scoreMode);
}
}
}

View File

@ -309,7 +309,7 @@ final class BooleanWeight extends Weight {
} else {
Scorer prohibitedScorer = prohibited.size() == 1
? prohibited.get(0)
: new DisjunctionSumScorer(this, prohibited, false);
: new DisjunctionSumScorer(this, prohibited, ScoreMode.COMPLETE_NO_SCORES);
if (prohibitedScorer.twoPhaseIterator() != null) {
// ReqExclBulkScorer can't deal efficiently with two-phased prohibited clauses
return null;

View File

@ -148,7 +148,7 @@ public final class DisjunctionMaxQuery extends Query implements Iterable<Query>
// only one sub-scorer in this segment
return scorers.get(0);
} else {
return new DisjunctionMaxScorer(this, tieBreakerMultiplier, scorers, scoreMode.needsScores());
return new DisjunctionMaxScorer(this, tieBreakerMultiplier, scorers, scoreMode);
}
}

View File

@ -21,6 +21,8 @@ import java.util.List;
import org.apache.lucene.util.MathUtil;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
/**
* The Scorer for DisjunctionMaxQuery. The union of all documents generated by the subquery scorers
* is generated in document number order. The score for each document is the maximum of the scores computed
@ -28,9 +30,9 @@ import org.apache.lucene.util.MathUtil;
* for the other subqueries that generate the document.
*/
final class DisjunctionMaxScorer extends DisjunctionScorer {
private final List<Scorer> subScorers;
/* Multiplier applied to non-maximum-scoring subqueries for a document as they are summed into the result. */
private final float tieBreakerMultiplier;
private final float maxScore;
/**
* Creates a new instance of DisjunctionMaxScorer
@ -43,40 +45,13 @@ final class DisjunctionMaxScorer extends DisjunctionScorer {
* @param subScorers
* The sub scorers this Scorer should iterate on
*/
DisjunctionMaxScorer(Weight weight, float tieBreakerMultiplier, List<Scorer> subScorers, boolean needsScores) throws IOException {
super(weight, subScorers, needsScores);
DisjunctionMaxScorer(Weight weight, float tieBreakerMultiplier, List<Scorer> subScorers, ScoreMode scoreMode) throws IOException {
super(weight, subScorers, scoreMode);
this.subScorers = subScorers;
this.tieBreakerMultiplier = tieBreakerMultiplier;
if (tieBreakerMultiplier < 0 || tieBreakerMultiplier > 1) {
throw new IllegalArgumentException("tieBreakerMultiplier must be in [0, 1]");
}
if (needsScores == false) {
this.maxScore = Float.MAX_VALUE;
} else {
float scoreMax = 0;
double otherScoreSum = 0;
for (Scorer scorer : subScorers) {
scorer.advanceShallow(0);
float subScore = scorer.getMaxScore(DocIdSetIterator.NO_MORE_DOCS);
if (subScore >= scoreMax) {
otherScoreSum += scoreMax;
scoreMax = subScore;
} else {
otherScoreSum += subScore;
}
}
if (tieBreakerMultiplier == 0) {
this.maxScore = scoreMax;
} else {
// The error of sums depends on the order in which values are summed up. In
// order to avoid this issue, we compute an upper bound of the value that
// the sum may take. If the max relative error is b, then it means that two
// sums are always within 2*b of each other.
otherScoreSum *= (1 + 2 * MathUtil.sumRelativeErrorBound(subScorers.size() - 1));
this.maxScore = (float) (scoreMax + otherScoreSum * tieBreakerMultiplier);
}
}
}
@Override
@ -95,8 +70,49 @@ final class DisjunctionMaxScorer extends DisjunctionScorer {
return (float) (scoreMax + otherScoreSum * tieBreakerMultiplier);
}
/**
 * Shallow-advances every sub-scorer positioned at or before {@code target} and
 * returns the smallest block end among them. A sub-scorer already positioned
 * after {@code target} (and not exhausted) caps the result at the doc just
 * before its current position instead of being advanced.
 */
@Override
public int advanceShallow(int target) throws IOException {
  int blockEnd = NO_MORE_DOCS;
  for (Scorer sub : subScorers) {
    final int subDoc = sub.docID();
    if (subDoc > target) {
      // Not advanced: the block may only extend up to the doc before this scorer.
      if (subDoc < NO_MORE_DOCS) {
        blockEnd = Math.min(subDoc - 1, blockEnd);
      }
      continue;
    }
    blockEnd = Math.min(sub.advanceShallow(target), blockEnd);
  }
  return blockEnd;
}
/**
 * Computes an upper bound of the score for documents up to {@code upTo}.
 * Only sub-scorers whose current doc ID is at or before {@code upTo} contribute:
 * the largest of their max scores is taken as-is and the remaining ones are
 * summed and weighted by {@code tieBreakerMultiplier}.
 */
@Override
public float getMaxScore(int upTo) throws IOException {
// NOTE(review): this return looks like the pre-change body kept by the diff
// rendering (it makes the statements below unreachable) - confirm which is live.
return maxScore;
float scoreMax = 0;
double otherScoreSum = 0;
for (Scorer scorer : subScorers) {
// Sub-scorers positioned beyond upTo are left out of the bound.
if (scorer.docID() <= upTo) {
float subScore = scorer.getMaxScore(upTo);
// Track the running maximum; whatever is displaced goes into the "others" sum.
if (subScore >= scoreMax) {
otherScoreSum += scoreMax;
scoreMax = subScore;
} else {
otherScoreSum += subScore;
}
}
}
if (tieBreakerMultiplier == 0) {
return scoreMax;
} else {
// The error of sums depends on the order in which values are summed up. In
// order to avoid this issue, we compute an upper bound of the value that
// the sum may take. If the max relative error is b, then it means that two
// sums are always within 2*b of each other.
otherScoreSum *= (1 + 2 * MathUtil.sumRelativeErrorBound(subScorers.size() - 1));
return (float) (scoreMax + otherScoreSum * tieBreakerMultiplier);
}
}
/**
 * Forwards the minimum competitive score to the block-max approximation
 * obtained from {@code getBlockMaxApprox()}.
 */
@Override
public void setMinCompetitiveScore(float minScore) {
  final BlockMaxDISI blockMax = getBlockMaxApprox();
  blockMax.setMinCompetitiveScore(minScore);
}
}

View File

@ -32,10 +32,11 @@ abstract class DisjunctionScorer extends Scorer {
private final boolean needsScores;
private final DisiPriorityQueue subScorers;
private final DisjunctionDISIApproximation approximation;
private final DocIdSetIterator approximation;
private final BlockMaxDISI blockMaxApprox;
private final TwoPhase twoPhase;
protected DisjunctionScorer(Weight weight, List<Scorer> subScorers, boolean needsScores) {
protected DisjunctionScorer(Weight weight, List<Scorer> subScorers, ScoreMode scoreMode) throws IOException {
super(weight);
if (subScorers.size() <= 1) {
throw new IllegalArgumentException("There must be at least 2 subScorers");
@ -45,8 +46,17 @@ abstract class DisjunctionScorer extends Scorer {
final DisiWrapper w = new DisiWrapper(scorer);
this.subScorers.add(w);
}
this.needsScores = needsScores;
this.approximation = new DisjunctionDISIApproximation(this.subScorers);
this.needsScores = scoreMode != ScoreMode.COMPLETE_NO_SCORES;
if (scoreMode == ScoreMode.TOP_SCORES) {
for (Scorer scorer : subScorers) {
scorer.advanceShallow(0);
}
this.blockMaxApprox = new BlockMaxDISI(new DisjunctionDISIApproximation(this.subScorers), this);
this.approximation = blockMaxApprox;
} else {
this.approximation = new DisjunctionDISIApproximation(this.subScorers);
this.blockMaxApprox = null;
}
boolean hasApproximation = false;
float sumMatchCost = 0;
@ -167,6 +177,10 @@ abstract class DisjunctionScorer extends Scorer {
return subScorers.top().doc;
}
/**
 * Returns the block-max approximation created by the constructor, which is
 * {@code null} unless the score mode was {@code ScoreMode.TOP_SCORES}.
 */
BlockMaxDISI getBlockMaxApprox() {
  return this.blockMaxApprox;
}
DisiWrapper getSubMatches() throws IOException {
if (twoPhase == null) {
return subScorers.topList();

View File

@ -28,8 +28,8 @@ final class DisjunctionSumScorer extends DisjunctionScorer {
* @param weight The weight to be used.
* @param subScorers Array of at least two subscorers.
*/
DisjunctionSumScorer(Weight weight, List<Scorer> subScorers, boolean needsScores) throws IOException {
super(weight, subScorers, needsScores);
DisjunctionSumScorer(Weight weight, List<Scorer> subScorers, ScoreMode scoreMode) throws IOException {
super(weight, subScorers, scoreMode);
}
@Override

View File

@ -16,18 +16,22 @@
*/
package org.apache.lucene.search;
import java.io.IOException;
import java.io.StringReader;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.MockAnalyzer;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.StringField;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
@ -522,6 +526,77 @@ public class TestDisjunctionMaxQuery extends LuceneTestCase {
assertEquals(bq.clauses().get(0), new BooleanClause(sub1, BooleanClause.Occur.SHOULD));
assertEquals(bq.clauses().get(1), new BooleanClause(sub2, BooleanClause.Occur.SHOULD));
}
/**
 * Runs the random top-docs check once per row below; each row is passed
 * through as the {@code freqs} argument and its length as {@code numFields}.
 */
public void testRandomTopDocs() throws Exception {
  final double[][] freqsPerRun = {
    {0.05f, 0.05f},
    {1.0f, 0.05f},
    {1.0f, 0.5f, 0.05f},
    {1.0f, 0.5f, 0.05f, 0f},
    {1.0f, 0.5f, 0.05f, 0f},
  };
  for (double[] freqs : freqsPerRun) {
    doTestRandomTopDocs(freqs.length, freqs);
  }
}
/**
 * Indexes random documents with {@code numFields} text fields named "0", "1", ...
 * and checks top-score collection for disjunction max queries over the term "a"
 * in each of those fields.
 *
 * @param numFields number of text fields per document; must equal {@code freqs.length}
 * @param freqs for field {@code j}, {@code freqs[j]} is the threshold compared against a
 *        random double: a draw below it gives the field no "a" token, otherwise 1 to 5
 */
private void doTestRandomTopDocs(int numFields, double... freqs) throws IOException {
assert numFields == freqs.length;
Directory dir = newDirectory();
IndexWriterConfig config = new IndexWriterConfig(new StandardAnalyzer());
IndexWriter w = new IndexWriter(dir, config);
int numDocs = atLeast(1000); // make sure some terms have skip data
for (int i = 0; i < numDocs; i++) {
Document doc = new Document();
for (int j = 0; j < numFields; j++) {
StringBuilder builder = new StringBuilder();
// Number of "a" tokens in field j: 0 when the draw falls below freqs[j], else 1..5.
int numAs = random().nextDouble() < freqs[j] ? 0 : 1 + random().nextInt(5);
for (int k = 0; k < numAs; k++) {
if (builder.length() > 0) {
builder.append(' ');
}
builder.append('a');
}
// Drawn once per field iteration, so a document may get this filter field several times.
if (random().nextBoolean()) {
doc.add(new StringField("field", "c", Field.Store.NO));
}
// Pad the field with 0 or 1..5 random integer tokens.
int numOthers = random().nextBoolean() ? 0 : 1 + random().nextInt(5);
for (int k = 0; k < numOthers; k++) {
if (builder.length() > 0) {
builder.append(' ');
}
builder.append(Integer.toString(random().nextInt()));
}
doc.add(new TextField(Integer.toString(j), new StringReader(builder.toString())));
}
w.addDocument(doc);
}
IndexReader reader = DirectoryReader.open(w);
w.close();
IndexSearcher searcher = newSearcher(reader);
// Four rounds: odd rounds use plain term clauses, even rounds randomly boost each clause.
for (int i = 0; i < 4; i++) {
List<Query> clauses = new ArrayList<>();
for (int j = 0; j < numFields; j++) {
if (i % 2 == 1) {
clauses.add(tq(Integer.toString(j), "a"));
} else {
float boost = random().nextBoolean() ? 0 : random().nextFloat();
// A boost of 0 falls back to the two-argument tq helper.
if (boost > 0) {
clauses.add(tq(Integer.toString(j), "a", boost));
} else {
clauses.add(tq(Integer.toString(j), "a"));
}
}
}
float tieBreaker = random().nextFloat();
// Check the disjunction max query on its own ...
Query query = new DisjunctionMaxQuery(clauses, tieBreaker);
CheckHits.checkTopScores(random(), query, searcher);
// ... and as a MUST clause combined with a FILTER on field:c.
query = new BooleanQuery.Builder()
.add(new DisjunctionMaxQuery(clauses, tieBreaker), BooleanClause.Occur.MUST)
.add(tq("field", "c"), BooleanClause.Occur.FILTER)
.build();
CheckHits.checkTopScores(random(), query, searcher);
}
reader.close();
dir.close();
}
/** macro */
protected Query tq(String f, String t) {