LUCENE-10480: Use BulkScorer to limit BMMScorer to only top-level disjunctions (#1018)

This commit is contained in:
Zach Chen 2022-07-19 18:59:19 -07:00 committed by GitHub
parent 3d7d85f245
commit 28ce8abb51
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 293 additions and 68 deletions

View File

@ -118,21 +118,6 @@ final class Boolean2ScorerSupplier extends ScorerSupplier {
leadCost);
}
// pure two terms disjunction
if (scoreMode == ScoreMode.TOP_SCORES
&& minShouldMatch <= 1
&& subs.get(Occur.FILTER).isEmpty()
&& subs.get(Occur.MUST).isEmpty()
&& subs.get(Occur.MUST_NOT).isEmpty()
&& subs.get(Occur.SHOULD).size() == 2) {
final List<Scorer> optionalScorers = new ArrayList<>();
for (ScorerSupplier scorer : subs.get(Occur.SHOULD)) {
optionalScorers.add(scorer.get(leadCost));
}
return new BlockMaxMaxscoreScorer(weight, optionalScorers);
}
// pure disjunction
if (subs.get(Occur.FILTER).isEmpty() && subs.get(Occur.MUST).isEmpty()) {
return excl(

View File

@ -34,7 +34,7 @@ final class BooleanWeight extends Weight {
final BooleanQuery query;
private static class WeightedBooleanClause {
protected static class WeightedBooleanClause {
final BooleanClause clause;
final Weight weight;
@ -191,6 +191,63 @@ final class BooleanWeight extends Weight {
// or null if it is not applicable
// pkg-private for forcing use of BooleanScorer in tests
BulkScorer optionalBulkScorer(LeafReaderContext context) throws IOException {
if (scoreMode == ScoreMode.TOP_SCORES) {
if (!query.isPureDisjunction() || weightedClauses.size() > 2) {
return null;
}
List<ScorerSupplier> optional = new ArrayList<>();
for (WeightedBooleanClause wc : weightedClauses) {
Weight w = wc.weight;
BooleanClause c = wc.clause;
if (c.getOccur() != Occur.SHOULD) {
continue;
}
ScorerSupplier scorer = w.scorerSupplier(context);
if (scorer != null) {
optional.add(scorer);
}
}
if (optional.size() <= 1) {
return null;
}
List<Scorer> optionalScorers = new ArrayList<>();
for (ScorerSupplier ss : optional) {
optionalScorers.add(ss.get(Long.MAX_VALUE));
}
return new BulkScorer() {
  // One BlockMaxMaxscoreScorer over the optional clause scorers, reused by every score() call.
  final Scorer bmmScorer = new BlockMaxMaxscoreScorer(BooleanWeight.this, optionalScorers);
  final DocIdSetIterator iterator = bmmScorer.iterator();

  /**
   * Feeds the collector every matching document in {@code [min, max)} that passes {@code
   * acceptDocs}.
   *
   * @param collector receives the scorer first, then each accepted document id
   * @param acceptDocs filter on document ids; {@code null} accepts every document
   * @param min first document id of the window (inclusive)
   * @param max end of the window (exclusive)
   * @return the document the iterator is positioned on when the window is exhausted, i.e. the
   *     first visited id that is not below {@code max}
   */
  @Override
  public int score(LeafCollector collector, Bits acceptDocs, int min, int max)
      throws IOException {
    collector.setScorer(bmmScorer);

    // Only advance when the scorer is still behind the window start.
    int candidate = bmmScorer.docID();
    if (candidate < min) {
      candidate = iterator.advance(min);
    }

    for (; candidate < max; candidate = iterator.nextDoc()) {
      final boolean accepted = acceptDocs == null || acceptDocs.get(candidate);
      if (accepted) {
        collector.collect(candidate);
      }
    }
    return candidate;
  }

  /** Returns the cost reported by the wrapped scorer's iterator. */
  @Override
  public long cost() {
    return iterator.cost();
  }
};
}
List<BulkScorer> optional = new ArrayList<BulkScorer>();
for (WeightedBooleanClause wc : weightedClauses) {
Weight w = wc.weight;
@ -329,11 +386,6 @@ final class BooleanWeight extends Weight {
@Override
public BulkScorer bulkScorer(LeafReaderContext context) throws IOException {
if (scoreMode == ScoreMode.TOP_SCORES) {
// If only the top docs are requested, use the default bulk scorer
// so that we can dynamically prune non-competitive hits.
return super.bulkScorer(context);
}
final BulkScorer bulkScorer = booleanScorer(context);
if (bulkScorer != null) {
// bulk scoring is applicable, use it

View File

@ -18,15 +18,16 @@ package org.apache.lucene.search;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.search.AssertingScorer;
import org.apache.lucene.tests.util.LuceneTestCase;
// These basic tests are similar to some of the tests in TestWANDScorer, and may not need to be kept
@ -62,26 +63,22 @@ public class TestBlockMaxMaxscoreScorer extends LuceneTestCase {
IndexSearcher searcher = newSearcher(reader);
Query query =
new BooleanQuery.Builder()
.add(
new BoostQuery(new ConstantScoreQuery(new TermQuery(new Term("foo", "A"))), 2),
BooleanClause.Occur.SHOULD)
.add(
new ConstantScoreQuery(new TermQuery(new Term("foo", "B"))),
BooleanClause.Occur.SHOULD)
.build();
new BlockMaxMaxscoreQuery(
new BooleanQuery.Builder()
.add(
new BoostQuery(
new ConstantScoreQuery(new TermQuery(new Term("foo", "A"))), 2),
BooleanClause.Occur.SHOULD)
.add(
new ConstantScoreQuery(new TermQuery(new Term("foo", "B"))),
BooleanClause.Occur.SHOULD)
.build());
Scorer scorer =
searcher
.createWeight(searcher.rewrite(query), ScoreMode.TOP_SCORES, 1)
.scorer(searcher.getIndexReader().leaves().get(0));
if (scorer instanceof AssertingScorer) {
assertTrue(((AssertingScorer) scorer).getIn() instanceof BlockMaxMaxscoreScorer);
} else {
assertTrue(scorer instanceof BlockMaxMaxscoreScorer);
}
assertEquals(0, scorer.iterator().nextDoc());
assertEquals(2 + 1, scorer.score(), 0);
@ -102,7 +99,7 @@ public class TestBlockMaxMaxscoreScorer extends LuceneTestCase {
}
}
public void testBasicsWithThreeDisjunctionClausesNotUseBMMScorer() throws Exception {
public void testBasicsWithThreeDisjunctionClauses() throws Exception {
try (Directory dir = newDirectory()) {
writeDocuments(dir);
@ -110,29 +107,26 @@ public class TestBlockMaxMaxscoreScorer extends LuceneTestCase {
IndexSearcher searcher = newSearcher(reader);
Query query =
new BooleanQuery.Builder()
.add(
new BoostQuery(new ConstantScoreQuery(new TermQuery(new Term("foo", "A"))), 2),
BooleanClause.Occur.SHOULD)
.add(
new ConstantScoreQuery(new TermQuery(new Term("foo", "B"))),
BooleanClause.Occur.SHOULD)
.add(
new BoostQuery(new ConstantScoreQuery(new TermQuery(new Term("foo", "C"))), 3),
BooleanClause.Occur.SHOULD)
.build();
new BlockMaxMaxscoreQuery(
new BooleanQuery.Builder()
.add(
new BoostQuery(
new ConstantScoreQuery(new TermQuery(new Term("foo", "A"))), 2),
BooleanClause.Occur.SHOULD)
.add(
new ConstantScoreQuery(new TermQuery(new Term("foo", "B"))),
BooleanClause.Occur.SHOULD)
.add(
new BoostQuery(
new ConstantScoreQuery(new TermQuery(new Term("foo", "C"))), 3),
BooleanClause.Occur.SHOULD)
.build());
Scorer scorer =
searcher
.createWeight(searcher.rewrite(query), ScoreMode.TOP_SCORES, 1)
.scorer(searcher.getIndexReader().leaves().get(0));
if (scorer instanceof AssertingScorer) {
assertTrue(((AssertingScorer) scorer).getIn() instanceof WANDScorer);
} else {
assertTrue(scorer instanceof WANDScorer);
}
assertEquals(0, scorer.iterator().nextDoc());
assertEquals(2 + 1, scorer.score(), 0);
@ -163,15 +157,16 @@ public class TestBlockMaxMaxscoreScorer extends LuceneTestCase {
Query query =
new BooleanQuery.Builder()
.add(
new BooleanQuery.Builder()
.add(
new BoostQuery(
new ConstantScoreQuery(new TermQuery(new Term("foo", "A"))), 2),
BooleanClause.Occur.SHOULD)
.add(
new ConstantScoreQuery(new TermQuery(new Term("foo", "B"))),
BooleanClause.Occur.SHOULD)
.build(),
new BlockMaxMaxscoreQuery(
new BooleanQuery.Builder()
.add(
new BoostQuery(
new ConstantScoreQuery(new TermQuery(new Term("foo", "A"))), 2),
BooleanClause.Occur.SHOULD)
.add(
new ConstantScoreQuery(new TermQuery(new Term("foo", "B"))),
BooleanClause.Occur.SHOULD)
.build()),
BooleanClause.Occur.MUST)
.add(new TermQuery(new Term("foo", "C")), BooleanClause.Occur.FILTER)
.build();
@ -214,11 +209,17 @@ public class TestBlockMaxMaxscoreScorer extends LuceneTestCase {
Query query =
new BooleanQuery.Builder()
.add(
new BoostQuery(new ConstantScoreQuery(new TermQuery(new Term("foo", "A"))), 2),
BooleanClause.Occur.SHOULD)
.add(
new ConstantScoreQuery(new TermQuery(new Term("foo", "B"))),
BooleanClause.Occur.SHOULD)
new BlockMaxMaxscoreQuery(
new BooleanQuery.Builder()
.add(
new BoostQuery(
new ConstantScoreQuery(new TermQuery(new Term("foo", "A"))), 2),
BooleanClause.Occur.SHOULD)
.add(
new ConstantScoreQuery(new TermQuery(new Term("foo", "B"))),
BooleanClause.Occur.SHOULD)
.build()),
BooleanClause.Occur.MUST)
.add(new TermQuery(new Term("foo", "C")), BooleanClause.Occur.MUST_NOT)
.build();
@ -252,4 +253,81 @@ public class TestBlockMaxMaxscoreScorer extends LuceneTestCase {
}
}
}
private static class BlockMaxMaxscoreQuery extends Query {
private final BooleanQuery query;
private BlockMaxMaxscoreQuery(BooleanQuery query) {
assert query.isPureDisjunction()
: "This test utility query is only used to create BlockMaxMaxscoreScorer for disjunctions.";
assert query.clauses().size() >= 2
: "There must be at least two optional clauses to use this test utility query.";
this.query = query;
}
@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
throws IOException {
return new Weight(query) {
@Override
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
// no-ops
return null;
}
@Override
public Scorer scorer(LeafReaderContext context) throws IOException {
BooleanWeight weight = (BooleanWeight) query.createWeight(searcher, scoreMode, boost);
List<Scorer> optionalScorers =
weight.weightedClauses.stream()
.map(wc -> wc.weight)
.map(
w -> {
try {
return w.scorerSupplier(context);
} catch (IOException e) {
throw new AssertionError(e);
}
})
.map(
ss -> {
try {
return ss.get(Long.MAX_VALUE);
} catch (IOException e) {
throw new AssertionError(e);
}
})
.toList();
return new BlockMaxMaxscoreScorer(weight, optionalScorers);
}
@Override
public boolean isCacheable(LeafReaderContext ctx) {
return false;
}
};
}
@Override
public String toString(String field) {
return "BlockMaxMaxscoreQuery";
}
@Override
public void visit(QueryVisitor visitor) {
// no-ops
}
@Override
public boolean equals(Object other) {
return sameClassAs(other) && query.equals(((BlockMaxMaxscoreQuery) other).query);
}
@Override
public int hashCode() {
return 31 * classHash() + query.hashCode();
}
}
}

View File

@ -16,6 +16,9 @@
*/
package org.apache.lucene.search;
import static com.carrotsearch.randomizedtesting.RandomizedTest.randomBoolean;
import com.carrotsearch.randomizedtesting.generators.RandomNumbers;
import com.carrotsearch.randomizedtesting.generators.RandomPicks;
import java.io.IOException;
import java.util.ArrayList;
@ -905,6 +908,113 @@ public class TestBooleanQuery extends LuceneTestCase {
dir.close();
}
// test BlockMaxMaxscoreScorer
/**
 * Searches a two-clause disjunction ({@code foo:A} boosted to 2, {@code foo:B} with score 1) over
 * a single-segment index and checks the number of hits, their order, and their scores.
 */
public void testDisjunctionTwoClausesMatchesCountAndScore() throws Exception {
  List<String[]> docContent =
      Arrays.asList(
          new String[] {"A", "B"}, // 0
          new String[] {"A"}, // 1
          new String[] {}, // 2
          new String[] {"A", "B", "C"}, // 3
          new String[] {"B"}, // 4
          new String[] {"B", "C"} // 5
          );

  // Expected {docId, score} pairs, in the order the search is expected to return them.
  int[][] matchDocScore = {
    {0, 2 + 1},
    {3, 2 + 1},
    {1, 2},
    {4, 1},
    {5, 1}
  };

  try (Directory dir = newDirectory()) {
    try (IndexWriter w =
        new IndexWriter(dir, newIndexWriterConfig().setMergePolicy(newLogMergePolicy()))) {
      for (String[] values : docContent) {
        Document doc = new Document();
        for (String value : values) {
          doc.add(new StringField("foo", value, Field.Store.NO));
        }
        w.addDocument(doc);
      }
      // Merge down to one segment so the doc ids above match insertion order.
      w.forceMerge(1);
    }

    try (IndexReader reader = DirectoryReader.open(dir)) {
      IndexSearcher searcher = newSearcher(reader);

      Query query =
          new BooleanQuery.Builder()
              .add(
                  new BoostQuery(new ConstantScoreQuery(new TermQuery(new Term("foo", "A"))), 2),
                  BooleanClause.Occur.SHOULD)
              .add(
                  new ConstantScoreQuery(new TermQuery(new Term("foo", "B"))),
                  BooleanClause.Occur.SHOULD)
              .build();

      TopDocs topDocs = searcher.search(query, 10);

      // The loop below is bounded by the number of returned hits, so without this check a search
      // that returns too few (or zero) hits would pass vacuously.
      assertEquals(matchDocScore.length, topDocs.scoreDocs.length);

      for (int i = 0; i < topDocs.scoreDocs.length; i++) {
        ScoreDoc scoreDoc = topDocs.scoreDocs[i];
        assertEquals(matchDocScore[i][0], scoreDoc.doc);
        assertEquals(matchDocScore[i][1], scoreDoc.score, 0);
      }
    }
  }
}
/**
 * Indexes a random number of distinct field values, each carried by a random number of
 * single-valued documents, then searches a disjunction over a random subset of those values and
 * checks that the hit count equals the number of documents holding a selected value.
 */
public void testDisjunctionRandomClausesMatchesCount() throws Exception {
  final int valueCount = RandomNumbers.randomIntBetween(random(), 1, 10);
  final int[] docsPerValue = new int[valueCount];
  int totalDocs = 0;
  for (int v = 0; v < docsPerValue.length; v++) {
    docsPerValue[v] = RandomNumbers.randomIntBetween(random(), 10, 50);
    totalDocs += docsPerValue[v];
  }

  try (Directory dir = newDirectory()) {
    try (IndexWriter writer =
        new IndexWriter(dir, newIndexWriterConfig().setMergePolicy(newLogMergePolicy()))) {
      for (int v = 0; v < valueCount; v++) {
        final String value = String.valueOf(v);
        for (int d = 0; d < docsPerValue[v]; d++) {
          Document document = new Document();
          document.add(new StringField("field", value, Field.Store.NO));
          writer.addDocument(document);
        }
      }
      writer.forceMerge(1);
    }

    try (IndexReader reader = DirectoryReader.open(dir)) {
      IndexSearcher searcher = newSearcher(reader);

      // Every document holds exactly one value, so the expected hit count is simply the sum of
      // the per-value document counts of the selected values.
      int expectedHits = 0;
      BooleanQuery.Builder disjunction = new BooleanQuery.Builder();
      for (int v = 0; v < valueCount; v++) {
        if (!randomBoolean()) {
          continue;
        }
        expectedHits += docsPerValue[v];
        disjunction.add(
            new TermQuery(new Term("field", String.valueOf(v))), BooleanClause.Occur.SHOULD);
      }

      // Ask for as many hits as there are documents so nothing is cut off.
      TopDocs hits = searcher.search(disjunction.build(), totalDocs);
      assertEquals(expectedHits, hits.scoreDocs.length);
    }
  }
}
public void testProhibitedMatchesCount() throws IOException {
Directory dir = newDirectory();
IndexWriter writer = new IndexWriter(dir, new IndexWriterConfig());