LUCENE-8204: Boolean queries with a mix of required and optional clauses are now faster if the total hit count is not required

This commit is contained in:
Jim Ferenczi 2018-08-08 15:49:58 +02:00
parent 6845bbff55
commit 49e3cca77f
5 changed files with 352 additions and 81 deletions

View File

@ -135,7 +135,10 @@ Optimizations
to run faster. (Adrien Grand)
* LUCENE-8439: Disjunction max queries can skip blocks to select the top documents
when the total hit count is not required. (Jim Ferenczi, Adrien Grand)
if the total hit count is not required. (Jim Ferenczi, Adrien Grand)
* LUCENE-8204: Boolean queries with a mix of required and optional clauses are
now faster if the total hit count is not required. (Jim Ferenczi, Adrien Grand)
======================= Lucene 7.5.0 =======================

View File

@ -111,7 +111,7 @@ final class Boolean2ScorerSupplier extends ScorerSupplier {
assert scoreMode.needsScores();
return new ReqOptSumScorer(
excl(req(subs.get(Occur.FILTER), subs.get(Occur.MUST), leadCost), subs.get(Occur.MUST_NOT), leadCost),
opt(subs.get(Occur.SHOULD), minShouldMatch, scoreMode, leadCost));
opt(subs.get(Occur.SHOULD), minShouldMatch, scoreMode, leadCost), scoreMode);
}
}

View File

@ -18,108 +18,160 @@ package org.apache.lucene.search;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
/** A Scorer for queries with a required part and an optional part.
* Delays advancing the optional part until a score() is needed, unless the
* optional part has become required to reach the minimum competitive score.
*/
class ReqOptSumScorer extends Scorer {
/** The required and optional scorers passed to the constructor.
* Both fields are final and are never reassigned.
*/
private final Scorer reqScorer;
private final Scorer optScorer;
private final float reqMaxScore;
private final DocIdSetIterator reqApproximation;
private final DocIdSetIterator optApproximation;
private final TwoPhaseIterator optTwoPhase;
private boolean optIsRequired;
private final DocIdSetIterator approximation;
private final TwoPhaseIterator twoPhase;
final MaxScoreSumPropagator maxScorePropagator;
/** Construct a <code>ReqOptScorer</code>.
private float minScore = 0;
private float reqMaxScore;
private boolean optIsRequired;
/**
* Construct a <code>ReqOptScorer</code>.
*
* @param reqScorer The required scorer. This must match.
* @param optScorer The optional scorer. This is used for scoring only.
* @param scoreMode How the produced scorers will be consumed.
*/
public ReqOptSumScorer(
Scorer reqScorer,
Scorer optScorer) throws IOException
{
public ReqOptSumScorer(Scorer reqScorer, Scorer optScorer, ScoreMode scoreMode) throws IOException {
super(reqScorer.weight);
assert reqScorer != null;
assert optScorer != null;
this.reqScorer = reqScorer;
this.optScorer = optScorer;
reqScorer.advanceShallow(0);
this.reqMaxScore = reqScorer.getMaxScore(DocIdSetIterator.NO_MORE_DOCS);
this.maxScorePropagator = new MaxScoreSumPropagator(Arrays.asList(reqScorer, optScorer));
final TwoPhaseIterator reqTwoPhase = reqScorer.twoPhaseIterator();
this.optTwoPhase = optScorer.twoPhaseIterator();
final DocIdSetIterator reqApproximation;
if (reqTwoPhase == null) {
reqApproximation = reqScorer.iterator();
} else {
reqApproximation= reqTwoPhase.approximation();
reqApproximation = reqTwoPhase.approximation();
}
if (optTwoPhase == null) {
optApproximation = optScorer.iterator();
} else {
optApproximation= optTwoPhase.approximation();
optApproximation = optTwoPhase.approximation();
}
if (scoreMode != ScoreMode.TOP_SCORES) {
approximation = reqApproximation;
this.reqMaxScore = Float.POSITIVE_INFINITY;
} else {
reqScorer.advanceShallow(0);
optScorer.advanceShallow(0);
this.reqMaxScore = reqScorer.getMaxScore(NO_MORE_DOCS);
this.approximation = new DocIdSetIterator() {
int upTo = -1;
float maxScore;
approximation = new DocIdSetIterator() {
private void moveToNextBlock(int target) throws IOException {
upTo = advanceShallow(target);
float reqMaxScoreBlock = reqScorer.getMaxScore(upTo);
maxScore = getMaxScore(upTo);
private int nextCommonDoc(int reqDoc) throws IOException {
int optDoc = optApproximation.docID();
if (optDoc > reqDoc) {
reqDoc = reqApproximation.advance(optDoc);
// Potentially move to a conjunction
optIsRequired = reqMaxScoreBlock < minScore;
}
while (true) { // invariant: reqDoc >= optDoc
if (reqDoc == optDoc) {
return reqDoc;
private int advanceImpacts(int target) throws IOException {
if (target > upTo) {
moveToNextBlock(target);
}
optDoc = optApproximation.advance(reqDoc);
if (optDoc == reqDoc) {
return reqDoc;
while (true) {
if (maxScore >= minScore) {
return target;
}
if (upTo == NO_MORE_DOCS) {
return NO_MORE_DOCS;
}
target = upTo + 1;
moveToNextBlock(target);
}
reqDoc = reqApproximation.advance(optDoc);
}
}
@Override
public int nextDoc() throws IOException {
int doc = reqApproximation.nextDoc();
if (optIsRequired) {
doc = nextCommonDoc(doc);
@Override
public int nextDoc() throws IOException {
return advanceInternal(reqApproximation.docID()+1);
}
return doc;
}
@Override
public int advance(int target) throws IOException {
int doc = reqApproximation.advance(target);
if (optIsRequired) {
doc = nextCommonDoc(doc);
@Override
public int advance(int target) throws IOException {
return advanceInternal(target);
}
return doc;
}
@Override
public int docID() {
return reqApproximation.docID();
}
private int advanceInternal(int target) throws IOException {
if (target == NO_MORE_DOCS) {
reqApproximation.advance(target);
return NO_MORE_DOCS;
}
int reqDoc = target;
advanceHead: for (;;) {
if (minScore != 0) {
reqDoc = advanceImpacts(reqDoc);
}
if (reqApproximation.docID() < reqDoc) {
reqDoc = reqApproximation.advance(reqDoc);
}
if (reqDoc == NO_MORE_DOCS || optIsRequired == false) {
return reqDoc;
}
@Override
public long cost() {
return reqApproximation.cost();
}
int upperBound = reqMaxScore < minScore ? NO_MORE_DOCS : upTo;
if (reqDoc > upperBound) {
continue;
}
};
// Find the next common doc within the current block
for (;;) { // invariant: reqDoc >= optDoc
int optDoc = optApproximation.docID();
if (optDoc < reqDoc) {
optDoc = optApproximation.advance(reqDoc);
}
if (optDoc > upperBound) {
reqDoc = upperBound + 1;
continue advanceHead;
}
if (optDoc != reqDoc) {
reqDoc = reqApproximation.advance(optDoc);
if (reqDoc > upperBound) {
continue advanceHead;
}
}
if (reqDoc == NO_MORE_DOCS || optDoc == reqDoc) {
return reqDoc;
}
}
}
}
@Override
public int docID() {
return reqApproximation.docID();
}
@Override
public long cost() {
return reqApproximation.cost();
}
};
}
if (reqTwoPhase == null && optTwoPhase == null) {
this.twoPhase = null;
@ -212,10 +264,13 @@ class ReqOptSumScorer extends Scorer {
@Override
public int advanceShallow(int target) throws IOException {
if (optScorer.docID() < target) {
optScorer.advanceShallow(target);
int upTo = reqScorer.advanceShallow(target);
if (optScorer.docID() <= target) {
upTo = Math.min(upTo, optScorer.advanceShallow(target));
} else if (optScorer.docID() != NO_MORE_DOCS) {
upTo = Math.min(upTo, optScorer.docID() - 1);
}
return reqScorer.advanceShallow(target);
return upTo;
}
@Override
@ -229,8 +284,9 @@ class ReqOptSumScorer extends Scorer {
@Override
public void setMinCompetitiveScore(float minScore) {
this.minScore = minScore;
// Potentially move to a conjunction
if (optIsRequired == false && minScore > reqMaxScore) {
if (reqMaxScore < minScore) {
optIsRequired = true;
}
}
@ -242,5 +298,4 @@ class ReqOptSumScorer extends Scorer {
children.add(new ChildScorer(optScorer, "SHOULD"));
return children;
}
}

View File

@ -17,11 +17,20 @@
package org.apache.lucene.search;
import java.io.IOException;
import java.util.Arrays;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.TermFrequencyAttribute;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.Field.Store;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.Term;
@ -97,6 +106,120 @@ public class TestReqOptSumScorer extends LuceneTestCase {
dir.close();
}
/**
 * Checks that a {@link ReqOptSumScorer} built with block max scores still visits every
 * document, and scores it like a plain MUST/SHOULD boolean query, once a minimum
 * competitive score has been set.
 */
public void testMaxBlock() throws IOException {
  Directory dir = newDirectory();
  IndexWriter w = new IndexWriter(dir, newIndexWriterConfig().setMergePolicy(newLogMergePolicy()));
  FieldType ft = new FieldType();
  ft.setIndexOptions(IndexOptions.DOCS_AND_FREQS);
  ft.setTokenized(true);
  ft.freeze();

  for (int i = 0; i < 1024; i++) {
    // create documents with an increasing number of As and, with probability 0.5, one B
    Document doc = new Document();
    doc.add(new Field("foo", new TermFreqTokenStream("a", i + 1), ft));
    if (random().nextFloat() < 0.5f) {
      doc.add(new Field("foo", new TermFreqTokenStream("b", 1), ft));
    }
    w.addDocument(doc);
  }
  w.forceMerge(1);
  w.close();

  IndexReader reader = DirectoryReader.open(dir);
  IndexSearcher searcher = newSearcher(reader);
  // freq == score
  searcher.setSimilarity(new TestSimilarity.SimpleSimilarity());

  final Query reqQ = new TermQuery(new Term("foo", "a"));
  final Query optQ = new TermQuery(new Term("foo", "b"));
  final Query boolQ = new BooleanQuery.Builder()
      .add(reqQ, Occur.MUST)
      .add(optQ, Occur.SHOULD)
      .build();

  // Scorer under test (block max scores enabled) and the reference scorer that
  // computes complete scores for the equivalent boolean query.
  Scorer actual = reqOptScorer(searcher, reqQ, optQ, true);
  Scorer expected = searcher
      .createWeight(boolQ, ScoreMode.COMPLETE, 1)
      .scorer(searcher.getIndexReader().leaves().get(0));

  actual.setMinCompetitiveScore(Math.nextUp(1f));

  // Checks that all blocks are fully visited
  for (int i = 0; i < 1024; i++) {
    assertEquals(i, actual.iterator().nextDoc());
    assertEquals(i, expected.iterator().nextDoc());
    // expected value first so that failure messages read correctly
    assertEquals(expected.score(), actual.score(), 0);
  }

  reader.close();
  dir.close();
}
/**
 * Exercises {@link ReqOptSumScorer} on a small hand-built segment with several minimum
 * competitive scores, checking the doc ids returned and their scores.
 */
public void testMaxScoreSegment() throws IOException {
  // Values of field "foo" for each document; the array index is the doc id.
  final String[][] fooValues = {
      { "A" },        // 0
      { "A" },        // 1
      { },            // 2
      { "A", "B" },   // 3
      { "A" },        // 4
      { "B" },        // 5
      { "A", "B" },   // 6
      { "B" }         // 7
  };

  Directory dir = newDirectory();
  IndexWriter w = new IndexWriter(dir, newIndexWriterConfig().setMergePolicy(newLogMergePolicy()));
  for (String[] values : Arrays.asList(fooValues)) {
    Document doc = new Document();
    for (String value : values) {
      doc.add(new StringField("foo", value, Store.NO));
    }
    w.addDocument(doc);
  }
  w.forceMerge(1);
  w.close();

  IndexReader reader = DirectoryReader.open(dir);
  IndexSearcher searcher = newSearcher(reader);
  final Query reqQ = new ConstantScoreQuery(new TermQuery(new Term("foo", "A")));
  final Query optQ = new ConstantScoreQuery(new TermQuery(new Term("foo", "B")));

  // Docs containing "A", scored 2 when they also contain "B" and 1 otherwise.
  final int[] reqDocs = { 0, 1, 3, 4, 6 };
  final float[] reqScores = { 1, 1, 2, 1, 2 };

  // Round 0: no minimum competitive score.
  // Round 1: minimum competitive score just below 1.
  // Both rounds are expected to return every doc that contains "A".
  Scorer scorer;
  for (int round = 0; round < 2; round++) {
    scorer = reqOptScorer(searcher, reqQ, optQ, false);
    if (round == 1) {
      scorer.setMinCompetitiveScore(Math.nextDown(1f));
    }
    for (int i = 0; i < reqDocs.length; i++) {
      assertEquals(reqDocs[i], scorer.iterator().nextDoc());
      assertEquals(reqScores[i], scorer.score(), 0);
    }
    assertEquals(DocIdSetIterator.NO_MORE_DOCS, scorer.iterator().nextDoc());
  }

  // Minimum competitive score just above 1: only docs with both "A" and "B" are returned.
  final int[] bothDocs = { 3, 6 };
  scorer = reqOptScorer(searcher, reqQ, optQ, false);
  scorer.setMinCompetitiveScore(Math.nextUp(1f));
  for (int docId : bothDocs) {
    assertEquals(docId, scorer.iterator().nextDoc());
    assertEquals(2, scorer.score(), 0);
  }
  assertEquals(DocIdSetIterator.NO_MORE_DOCS, scorer.iterator().nextDoc());

  // Minimum competitive score just above 2, with block max scores: nothing is returned.
  scorer = reqOptScorer(searcher, reqQ, optQ, true);
  scorer.setMinCompetitiveScore(Math.nextUp(2f));
  assertEquals(DocIdSetIterator.NO_MORE_DOCS, scorer.iterator().nextDoc());

  reader.close();
  dir.close();
}
public void testRandomFrequentOpt() throws IOException {
doTestRandom(0.5);
}
@ -150,41 +273,125 @@ public class TestReqOptSumScorer extends LuceneTestCase {
searcher.search(query, coll);
ScoreDoc[] expectedFiltered = coll.topDocs().scoreDocs;
for (int i = 0; i < 4; ++i) {
Query must = mustTerm;
if (i % 2 == 1) {
must = new RandomApproximationQuery(must, random());
}
Query should = shouldTerm;
if (i >= 2) {
should = new RandomApproximationQuery(should, random());
}
query = new BooleanQuery.Builder()
.add(must, Occur.MUST)
.add(should, Occur.SHOULD)
CheckHits.checkTopScores(random(), query, searcher);
{
Query q = new BooleanQuery.Builder()
.add(new RandomApproximationQuery(mustTerm, random()), Occur.MUST)
.add(shouldTerm, Occur.SHOULD)
.build();
coll = TopScoreDocCollector.create(10, null, 1);
searcher.search(query, coll);
searcher.search(q, coll);
ScoreDoc[] actual = coll.topDocs().scoreDocs;
CheckHits.checkEqual(query, expected, actual);
q = new BooleanQuery.Builder()
.add(mustTerm, Occur.MUST)
.add(new RandomApproximationQuery(shouldTerm, random()), Occur.SHOULD)
.build();
coll = TopScoreDocCollector.create(10, null, 1);
searcher.search(q, coll);
actual = coll.topDocs().scoreDocs;
CheckHits.checkEqual(q, expected, actual);
q = new BooleanQuery.Builder()
.add(new RandomApproximationQuery(mustTerm, random()), Occur.MUST)
.add(new RandomApproximationQuery(shouldTerm, random()), Occur.SHOULD)
.build();
coll = TopScoreDocCollector.create(10, null, 1);
searcher.search(q, coll);
actual = coll.topDocs().scoreDocs;
CheckHits.checkEqual(q, expected, actual);
}
{
Query nestedQ = new BooleanQuery.Builder()
.add(query, Occur.MUST)
.add(new TermQuery(new Term("f", "C")), Occur.FILTER)
.build();
CheckHits.checkTopScores(random(), nestedQ, searcher);
query = new BooleanQuery.Builder()
.add(query, Occur.MUST)
.add(new RandomApproximationQuery(new TermQuery(new Term("f", "C")), random()), Occur.FILTER)
.build();
coll = TopScoreDocCollector.create(10, null, 1);
searcher.search(query, coll);
searcher.search(nestedQ, coll);
ScoreDoc[] actualFiltered = coll.topDocs().scoreDocs;
CheckHits.checkEqual(nestedQ, expectedFiltered, actualFiltered);
}
CheckHits.checkEqual(query, expectedFiltered, actualFiltered);
{
query = new BooleanQuery.Builder()
.add(query, Occur.MUST)
.add(new TermQuery(new Term("f", "C")), Occur.SHOULD)
.build();
CheckHits.checkTopScores(random(), query, searcher);
query = new BooleanQuery.Builder()
.add(new TermQuery(new Term("f", "C")), Occur.MUST)
.add(query, Occur.SHOULD)
.build();
CheckHits.checkTopScores(random(), query, searcher);
}
r.close();
dir.close();
}
/**
 * Builds a {@link ReqOptSumScorer} over the first leaf of the searcher's reader, with both
 * sub-scorers created in {@code TOP_SCORES} mode.
 *
 * @param searcher       searcher used to create the weights
 * @param reqQ           query for the required part
 * @param optQ           query for the optional part
 * @param withBlockScore when {@code false}, the returned scorer overrides
 *                       {@code getMaxScore} to always return positive infinity
 */
private static Scorer reqOptScorer(IndexSearcher searcher, Query reqQ, Query optQ, boolean withBlockScore) throws IOException {
  final Scorer required = searcher
      .createWeight(reqQ, ScoreMode.TOP_SCORES, 1)
      .scorer(searcher.getIndexReader().leaves().get(0));
  final Scorer optional = searcher
      .createWeight(optQ, ScoreMode.TOP_SCORES, 1)
      .scorer(searcher.getIndexReader().leaves().get(0));

  if (!withBlockScore) {
    // Report an unbounded max score from the combined scorer.
    return new ReqOptSumScorer(required, optional, ScoreMode.TOP_SCORES) {
      @Override
      public float getMaxScore(int upTo) {
        return Float.POSITIVE_INFINITY;
      }
    };
  }
  return new ReqOptSumScorer(required, optional, ScoreMode.TOP_SCORES);
}
/**
 * Token stream that emits exactly one token per reset: the configured term, with the
 * configured term frequency.
 */
private static class TermFreqTokenStream extends TokenStream {
  private final String term;
  private final int termFreq;
  private final CharTermAttribute termAtt = addAttribute(CharTermAttribute.class);
  private final TermFrequencyAttribute termFreqAtt = addAttribute(TermFrequencyAttribute.class);
  /** True once the single token has been emitted since the last {@link #reset()}. */
  private boolean emitted;

  public TermFreqTokenStream(String term, int termFreq) {
    this.term = term;
    this.termFreq = termFreq;
  }

  /** Emits the single token on the first call after a reset, then reports exhaustion. */
  @Override
  public boolean incrementToken() {
    if (!emitted) {
      clearAttributes();
      termAtt.append(term);
      termFreqAtt.setTermFrequency(termFreq);
      emitted = true;
      return true;
    }
    return false;
  }

  /** Re-arms the stream so that the token is emitted again. */
  @Override
  public void reset() {
    emitted = false;
  }
}
}

View File

@ -110,6 +110,12 @@ public class RandomApproximationQuery extends Query {
@Override
public int advanceShallow(int target) throws IOException {
  final int scorerDoc = scorer.docID();
  // The random approximation may sit on a doc id that the underlying scorer does not
  // match; such doc ids always come *before* the scorer's next matching doc. When the
  // scorer is already past the target and the approximation is on such a doc, shallow
  // advance from the scorer's own position instead of the target.
  final boolean scorerIsAhead =
      scorerDoc > target && twoPhaseView.approximation.docID() != scorerDoc;
  return scorer.advanceShallow(scorerIsAhead ? scorerDoc : target);
}
@ -120,12 +126,12 @@ public class RandomApproximationQuery extends Query {
@Override
public int docID() {
return scorer.docID();
return twoPhaseView.approximation().docID();
}
@Override
public DocIdSetIterator iterator() {
return scorer.iterator();
return TwoPhaseIterator.asDocIdSetIterator(twoPhaseView);
}
}