LUCENE-8439: Disjunction max queries can skip blocks to select the top documents when the total hit count is not required

This commit is contained in:
Jim Ferenczi 2018-08-08 12:34:42 +02:00
parent 38bf976cd4
commit ba9b18f367
9 changed files with 244 additions and 42 deletions

View File

@ -134,6 +134,9 @@ Optimizations
or phrase queries as sub queries, which know how to leverage this information or phrase queries as sub queries, which know how to leverage this information
to run faster. (Adrien Grand) to run faster. (Adrien Grand)
* LUCENE-8439: Disjunction max queries can skip blocks to select the top documents
when the total hit count is not required. (Jim Ferenczi, Adrien Grand)
======================= Lucene 7.5.0 ======================= ======================= Lucene 7.5.0 =======================
API Changes: API Changes:

View File

@ -0,0 +1,94 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
/**
 * A {@link DocIdSetIterator} decorator that uses the per-block maximum score
 * reported by a {@link Scorer} to jump over ranges of documents that cannot
 * reach the current minimum competitive score.
 *
 * <p>Skipping only has an effect once {@link #setMinCompetitiveScore(float)}
 * has been called with a positive value: the minimum score starts at
 * {@code 0}, and a block is skipped only when its maximum score is not
 * greater than or equal to the minimum score.
 *
 * @lucene.internal
 */
public class BlockMaxDISI extends DocIdSetIterator {

  /** Scorer queried for block boundaries and block-level maximum scores. */
  protected final Scorer scorer;

  /** Wrapped iterator that performs the actual document iteration. */
  private final DocIdSetIterator delegate;

  /** Lowest score that is still competitive; {@code -1} disables skipping entirely. */
  private float minScore;

  /** Maximum score of the block that ends at {@link #blockEnd}. */
  private float blockMaxScore;

  /** Inclusive upper bound of the current block; {@code -1} until a block is loaded. */
  private int blockEnd = -1;

  /**
   * Creates a block-max aware view over {@code iterator}.
   *
   * @param iterator the iterator to wrap
   * @param scorer   the scorer whose block-level max scores drive the skipping
   */
  public BlockMaxDISI(DocIdSetIterator iterator, Scorer scorer) {
    this.scorer = scorer;
    this.delegate = iterator;
  }

  @Override
  public int docID() {
    return delegate.docID();
  }

  @Override
  public int nextDoc() throws IOException {
    final int next = docID() + 1;
    return advance(next);
  }

  @Override
  public int advance(int target) throws IOException {
    // First hop over non-competitive blocks, then let the wrapped iterator
    // find the first real document at or after the resulting position.
    return delegate.advance(skipNonCompetitiveBlocks(target));
  }

  @Override
  public long cost() {
    return delegate.cost();
  }

  /**
   * Sets the score below which blocks are considered non-competitive.
   *
   * @param minScore the new minimum competitive score
   */
  public void setMinCompetitiveScore(float minScore) {
    this.minScore = minScore;
  }

  /** Loads the boundary and the maximum score of the block that contains {@code target}. */
  private void loadBlock(int target) throws IOException {
    blockEnd = scorer.advanceShallow(target);
    blockMaxScore = scorer.getMaxScore(blockEnd);
  }

  /**
   * Returns the first position at or after {@code target} that lies in a
   * block whose maximum score is competitive, or {@code NO_MORE_DOCS} when
   * no such block remains.
   */
  private int skipNonCompetitiveBlocks(int target) throws IOException {
    if (minScore == -1 || target == NO_MORE_DOCS) {
      return target;
    }
    if (target > blockEnd) {
      loadBlock(target);
    }
    // Written as the exact negation of ">=" so the comparison result is
    // unchanged for every float value.
    while ((blockMaxScore >= minScore) == false) {
      if (blockEnd == NO_MORE_DOCS) {
        return NO_MORE_DOCS;
      }
      target = blockEnd + 1;
      loadBlock(target);
    }
    return target;
  }
}

View File

@ -187,7 +187,7 @@ final class Boolean2ScorerSupplier extends ScorerSupplier {
} else if (scoreMode == ScoreMode.TOP_SCORES) { } else if (scoreMode == ScoreMode.TOP_SCORES) {
return new WANDScorer(weight, optionalScorers); return new WANDScorer(weight, optionalScorers);
} else { } else {
return new DisjunctionSumScorer(weight, optionalScorers, scoreMode.needsScores()); return new DisjunctionSumScorer(weight, optionalScorers, scoreMode);
} }
} }
} }

View File

@ -309,7 +309,7 @@ final class BooleanWeight extends Weight {
} else { } else {
Scorer prohibitedScorer = prohibited.size() == 1 Scorer prohibitedScorer = prohibited.size() == 1
? prohibited.get(0) ? prohibited.get(0)
: new DisjunctionSumScorer(this, prohibited, false); : new DisjunctionSumScorer(this, prohibited, ScoreMode.COMPLETE_NO_SCORES);
if (prohibitedScorer.twoPhaseIterator() != null) { if (prohibitedScorer.twoPhaseIterator() != null) {
// ReqExclBulkScorer can't deal efficiently with two-phased prohibited clauses // ReqExclBulkScorer can't deal efficiently with two-phased prohibited clauses
return null; return null;

View File

@ -148,7 +148,7 @@ public final class DisjunctionMaxQuery extends Query implements Iterable<Query>
// only one sub-scorer in this segment // only one sub-scorer in this segment
return scorers.get(0); return scorers.get(0);
} else { } else {
return new DisjunctionMaxScorer(this, tieBreakerMultiplier, scorers, scoreMode.needsScores()); return new DisjunctionMaxScorer(this, tieBreakerMultiplier, scorers, scoreMode);
} }
} }

View File

@ -21,6 +21,8 @@ import java.util.List;
import org.apache.lucene.util.MathUtil; import org.apache.lucene.util.MathUtil;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
/** /**
* The Scorer for DisjunctionMaxQuery. The union of all documents generated by the subquery scorers * The Scorer for DisjunctionMaxQuery. The union of all documents generated by the subquery scorers
* is generated in document number order. The score for each document is the maximum of the scores computed * is generated in document number order. The score for each document is the maximum of the scores computed
@ -28,9 +30,9 @@ import org.apache.lucene.util.MathUtil;
* for the other subqueries that generate the document. * for the other subqueries that generate the document.
*/ */
final class DisjunctionMaxScorer extends DisjunctionScorer { final class DisjunctionMaxScorer extends DisjunctionScorer {
private final List<Scorer> subScorers;
/* Multiplier applied to non-maximum-scoring subqueries for a document as they are summed into the result. */ /* Multiplier applied to non-maximum-scoring subqueries for a document as they are summed into the result. */
private final float tieBreakerMultiplier; private final float tieBreakerMultiplier;
private final float maxScore;
/** /**
* Creates a new instance of DisjunctionMaxScorer * Creates a new instance of DisjunctionMaxScorer
@ -43,40 +45,13 @@ final class DisjunctionMaxScorer extends DisjunctionScorer {
* @param subScorers * @param subScorers
* The sub scorers this Scorer should iterate on * The sub scorers this Scorer should iterate on
*/ */
DisjunctionMaxScorer(Weight weight, float tieBreakerMultiplier, List<Scorer> subScorers, boolean needsScores) throws IOException { DisjunctionMaxScorer(Weight weight, float tieBreakerMultiplier, List<Scorer> subScorers, ScoreMode scoreMode) throws IOException {
super(weight, subScorers, needsScores); super(weight, subScorers, scoreMode);
this.subScorers = subScorers;
this.tieBreakerMultiplier = tieBreakerMultiplier; this.tieBreakerMultiplier = tieBreakerMultiplier;
if (tieBreakerMultiplier < 0 || tieBreakerMultiplier > 1) { if (tieBreakerMultiplier < 0 || tieBreakerMultiplier > 1) {
throw new IllegalArgumentException("tieBreakerMultiplier must be in [0, 1]"); throw new IllegalArgumentException("tieBreakerMultiplier must be in [0, 1]");
} }
if (needsScores == false) {
this.maxScore = Float.MAX_VALUE;
} else {
float scoreMax = 0;
double otherScoreSum = 0;
for (Scorer scorer : subScorers) {
scorer.advanceShallow(0);
float subScore = scorer.getMaxScore(DocIdSetIterator.NO_MORE_DOCS);
if (subScore >= scoreMax) {
otherScoreSum += scoreMax;
scoreMax = subScore;
} else {
otherScoreSum += subScore;
}
}
if (tieBreakerMultiplier == 0) {
this.maxScore = scoreMax;
} else {
// The error of sums depends on the order in which values are summed up. In
// order to avoid this issue, we compute an upper bound of the value that
// the sum may take. If the max relative error is b, then it means that two
// sums are always within 2*b of each other.
otherScoreSum *= (1 + 2 * MathUtil.sumRelativeErrorBound(subScorers.size() - 1));
this.maxScore = (float) (scoreMax + otherScoreSum * tieBreakerMultiplier);
}
}
} }
@Override @Override
@ -95,8 +70,49 @@ final class DisjunctionMaxScorer extends DisjunctionScorer {
return (float) (scoreMax + otherScoreSum * tieBreakerMultiplier); return (float) (scoreMax + otherScoreSum * tieBreakerMultiplier);
} }
/**
 * Computes the block boundary for this disjunction at {@code target}: the
 * smallest boundary among the sub-scorers, starting from {@code NO_MORE_DOCS}.
 * A sub-scorer positioned at or before {@code target} contributes the value
 * of its own {@code advanceShallow(target)}; one positioned after
 * {@code target} (and not exhausted) contributes the doc ID just before its
 * current document; exhausted sub-scorers contribute nothing.
 */
@Override
public int advanceShallow(int target) throws IOException {
  int blockEnd = NO_MORE_DOCS;
  for (Scorer sub : subScorers) {
    final int subDoc = sub.docID();
    if (subDoc > target) {
      // Already beyond the target: the block must stop right before this
      // sub-scorer's current document, unless the sub-scorer is exhausted.
      if (subDoc < NO_MORE_DOCS) {
        blockEnd = Math.min(subDoc - 1, blockEnd);
      }
      continue;
    }
    blockEnd = Math.min(sub.advanceShallow(target), blockEnd);
  }
  return blockEnd;
}
@Override @Override
public float getMaxScore(int upTo) throws IOException { public float getMaxScore(int upTo) throws IOException {
return maxScore; float scoreMax = 0;
double otherScoreSum = 0;
for (Scorer scorer : subScorers) {
if (scorer.docID() <= upTo) {
float subScore = scorer.getMaxScore(upTo);
if (subScore >= scoreMax) {
otherScoreSum += scoreMax;
scoreMax = subScore;
} else {
otherScoreSum += subScore;
}
}
}
if (tieBreakerMultiplier == 0) {
return scoreMax;
} else {
// The error of sums depends on the order in which values are summed up. In
// order to avoid this issue, we compute an upper bound of the value that
// the sum may take. If the max relative error is b, then it means that two
// sums are always within 2*b of each other.
otherScoreSum *= (1 + 2 * MathUtil.sumRelativeErrorBound(subScorers.size() - 1));
return (float) (scoreMax + otherScoreSum * tieBreakerMultiplier);
}
}
@Override
public void setMinCompetitiveScore(float minScore) {
getBlockMaxApprox().setMinCompetitiveScore(minScore);
} }
} }

View File

@ -32,10 +32,11 @@ abstract class DisjunctionScorer extends Scorer {
private final boolean needsScores; private final boolean needsScores;
private final DisiPriorityQueue subScorers; private final DisiPriorityQueue subScorers;
private final DisjunctionDISIApproximation approximation; private final DocIdSetIterator approximation;
private final BlockMaxDISI blockMaxApprox;
private final TwoPhase twoPhase; private final TwoPhase twoPhase;
protected DisjunctionScorer(Weight weight, List<Scorer> subScorers, boolean needsScores) { protected DisjunctionScorer(Weight weight, List<Scorer> subScorers, ScoreMode scoreMode) throws IOException {
super(weight); super(weight);
if (subScorers.size() <= 1) { if (subScorers.size() <= 1) {
throw new IllegalArgumentException("There must be at least 2 subScorers"); throw new IllegalArgumentException("There must be at least 2 subScorers");
@ -45,8 +46,17 @@ abstract class DisjunctionScorer extends Scorer {
final DisiWrapper w = new DisiWrapper(scorer); final DisiWrapper w = new DisiWrapper(scorer);
this.subScorers.add(w); this.subScorers.add(w);
} }
this.needsScores = needsScores; this.needsScores = scoreMode != ScoreMode.COMPLETE_NO_SCORES;
this.approximation = new DisjunctionDISIApproximation(this.subScorers); if (scoreMode == ScoreMode.TOP_SCORES) {
for (Scorer scorer : subScorers) {
scorer.advanceShallow(0);
}
this.blockMaxApprox = new BlockMaxDISI(new DisjunctionDISIApproximation(this.subScorers), this);
this.approximation = blockMaxApprox;
} else {
this.approximation = new DisjunctionDISIApproximation(this.subScorers);
this.blockMaxApprox = null;
}
boolean hasApproximation = false; boolean hasApproximation = false;
float sumMatchCost = 0; float sumMatchCost = 0;
@ -167,6 +177,10 @@ abstract class DisjunctionScorer extends Scorer {
return subScorers.top().doc; return subScorers.top().doc;
} }
/**
 * Returns the block-max wrapper around this scorer's approximation.
 * The wrapper is only created when the scorer is constructed with
 * {@code ScoreMode.TOP_SCORES}; for any other score mode this returns
 * {@code null}, so callers must not dereference the result unconditionally.
 */
BlockMaxDISI getBlockMaxApprox() {
return blockMaxApprox;
}
DisiWrapper getSubMatches() throws IOException { DisiWrapper getSubMatches() throws IOException {
if (twoPhase == null) { if (twoPhase == null) {
return subScorers.topList(); return subScorers.topList();

View File

@ -28,8 +28,8 @@ final class DisjunctionSumScorer extends DisjunctionScorer {
* @param weight The weight to be used. * @param weight The weight to be used.
* @param subScorers Array of at least two subscorers. * @param subScorers Array of at least two subscorers.
*/ */
DisjunctionSumScorer(Weight weight, List<Scorer> subScorers, boolean needsScores) throws IOException { DisjunctionSumScorer(Weight weight, List<Scorer> subScorers, ScoreMode scoreMode) throws IOException {
super(weight, subScorers, needsScores); super(weight, subScorers, scoreMode);
} }
@Override @Override

View File

@ -16,18 +16,22 @@
*/ */
package org.apache.lucene.search; package org.apache.lucene.search;
import java.io.IOException; import java.io.IOException;
import java.io.StringReader;
import java.text.DecimalFormat; import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols; import java.text.DecimalFormatSymbols;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List;
import java.util.Locale; import java.util.Locale;
import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.MockAnalyzer; import org.apache.lucene.analysis.MockAnalyzer;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.document.Document; import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field; import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType; import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.StringField;
import org.apache.lucene.document.TextField; import org.apache.lucene.document.TextField;
import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexReader;
@ -522,6 +526,77 @@ public class TestDisjunctionMaxQuery extends LuceneTestCase {
assertEquals(bq.clauses().get(0), new BooleanClause(sub1, BooleanClause.Occur.SHOULD)); assertEquals(bq.clauses().get(0), new BooleanClause(sub1, BooleanClause.Occur.SHOULD));
assertEquals(bq.clauses().get(1), new BooleanClause(sub2, BooleanClause.Occur.SHOULD)); assertEquals(bq.clauses().get(1), new BooleanClause(sub2, BooleanClause.Occur.SHOULD));
} }
/**
 * Runs the randomized top-docs check against several index layouts. Each row
 * holds one value per field; the value is the probability used by
 * {@code doTestRandomTopDocs} to leave the term "a" out of that field. The
 * 4-field layout is intentionally run twice, each run drawing fresh random data.
 */
public void testRandomTopDocs() throws Exception {
  // float literals are kept so the values widened to double stay the same
  final double[][] scenarios = {
    {0.05f, 0.05f},
    {1.0f, 0.05f},
    {1.0f, 0.5f, 0.05f},
    {1.0f, 0.5f, 0.05f, 0f},
    {1.0f, 0.5f, 0.05f, 0f},
  };
  for (double[] freqs : scenarios) {
    doTestRandomTopDocs(freqs.length, freqs);
  }
}
/**
 * Indexes random documents with {@code numFields} text fields (named "0",
 * "1", ...) and verifies, through {@code CheckHits.checkTopScores}, the top
 * hits of a {@link DisjunctionMaxQuery} over the term "a" in every field,
 * both on its own and combined with a FILTER clause on {@code field:c}.
 *
 * @param numFields number of text fields per document; must equal {@code freqs.length}
 * @param freqs     for each field, the probability that a document gets no "a" token in it
 */
private void doTestRandomTopDocs(int numFields, double... freqs) throws IOException {
  assert numFields == freqs.length;
  Directory directory = newDirectory();
  IndexWriterConfig writerConfig = new IndexWriterConfig(new StandardAnalyzer());
  IndexWriter writer = new IndexWriter(directory, writerConfig);
  // enough documents so that some terms have skip data
  int docCount = atLeast(1000);
  for (int docIdx = 0; docIdx < docCount; docIdx++) {
    Document document = new Document();
    for (int fieldIdx = 0; fieldIdx < numFields; fieldIdx++) {
      StringBuilder text = new StringBuilder();

      // Occurrences of the queried term "a" in this field.
      int aCount;
      if (random().nextDouble() < freqs[fieldIdx]) {
        aCount = 0;
      } else {
        aCount = 1 + random().nextInt(5);
      }
      for (int k = 0; k < aCount; k++) {
        if (text.length() > 0) {
          text.append(' ');
        }
        text.append('a');
      }

      // Marker term matched by the FILTER clause below.
      if (random().nextBoolean()) {
        document.add(new StringField("field", "c", Field.Store.NO));
      }

      // Random filler tokens.
      int fillerCount;
      if (random().nextBoolean()) {
        fillerCount = 0;
      } else {
        fillerCount = 1 + random().nextInt(5);
      }
      for (int k = 0; k < fillerCount; k++) {
        if (text.length() > 0) {
          text.append(' ');
        }
        text.append(Integer.toString(random().nextInt()));
      }

      document.add(new TextField(Integer.toString(fieldIdx), new StringReader(text.toString())));
    }
    writer.addDocument(document);
  }
  IndexReader reader = DirectoryReader.open(writer);
  writer.close();
  IndexSearcher searcher = newSearcher(reader);

  for (int round = 0; round < 4; round++) {
    // Even rounds may attach a random boost to a clause; odd rounds never do.
    final boolean mayBoost = round % 2 == 0;
    List<Query> clauses = new ArrayList<>();
    for (int fieldIdx = 0; fieldIdx < numFields; fieldIdx++) {
      final String fieldName = Integer.toString(fieldIdx);
      float boost = 0;
      if (mayBoost && random().nextBoolean() == false) {
        boost = random().nextFloat();
      }
      final Query clause;
      if (boost > 0) {
        clause = tq(fieldName, "a", boost);
      } else {
        clause = tq(fieldName, "a");
      }
      clauses.add(clause);
    }
    float tieBreaker = random().nextFloat();

    Query dismax = new DisjunctionMaxQuery(clauses, tieBreaker);
    CheckHits.checkTopScores(random(), dismax, searcher);

    Query filtered = new BooleanQuery.Builder()
        .add(new DisjunctionMaxQuery(clauses, tieBreaker), BooleanClause.Occur.MUST)
        .add(tq("field", "c"), BooleanClause.Occur.FILTER)
        .build();
    CheckHits.checkTopScores(random(), filtered, searcher);
  }

  reader.close();
  directory.close();
}
/** macro */ /** macro */
protected Query tq(String f, String t) { protected Query tq(String f, String t) {