LUCENE-10418: Optimize `Query#rewrite` in the non-scoring case. (#672)

This commit is contained in:
Adrien Grand 2022-03-17 16:41:55 +01:00 committed by GitHub
parent 86bd921fce
commit 8fb6543280
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 226 additions and 101 deletions

View File

@ -35,7 +35,9 @@ Improvements
Optimizations
---------------------
(No changes)
* LUCENE-10418: More `Query#rewrite` optimizations for the non-scoring case.
(Adrien Grand)
Bug Fixes
---------------------

View File

@ -191,38 +191,41 @@ public class BooleanQuery extends Query implements Iterable<BooleanClause> {
return clauses.iterator();
}
private BooleanQuery rewriteNoScoring() {
boolean keepShould =
// Utility method for rewriting BooleanQuery when scores are not needed.
// This is called from ConstantScoreQuery#rewrite
BooleanQuery rewriteNoScoring(IndexReader reader) throws IOException {
boolean actuallyRewritten = false;
BooleanQuery.Builder newQuery =
new BooleanQuery.Builder().setMinimumNumberShouldMatch(getMinimumNumberShouldMatch());
final boolean keepShould =
getMinimumNumberShouldMatch() > 0
|| (clauseSets.get(Occur.MUST).size() + clauseSets.get(Occur.FILTER).size() == 0);
if (clauseSets.get(Occur.MUST).size() == 0 && keepShould) {
return this;
}
BooleanQuery.Builder newQuery = new BooleanQuery.Builder();
newQuery.setMinimumNumberShouldMatch(getMinimumNumberShouldMatch());
for (BooleanClause clause : clauses) {
switch (clause.getOccur()) {
case MUST:
{
newQuery.add(clause.getQuery(), Occur.FILTER);
break;
}
case SHOULD:
{
if (keepShould) {
newQuery.add(clause);
}
break;
}
case FILTER:
case MUST_NOT:
default:
{
newQuery.add(clause);
}
Query query = clause.getQuery();
Query rewritten = new ConstantScoreQuery(query).rewrite(reader);
if (rewritten instanceof ConstantScoreQuery) {
rewritten = ((ConstantScoreQuery) rewritten).getQuery();
}
BooleanClause.Occur occur = clause.getOccur();
if (occur == Occur.SHOULD && keepShould == false) {
// ignore clause
actuallyRewritten = true;
} else if (occur == Occur.MUST) {
// replace MUST clauses with FILTER clauses
newQuery.add(rewritten, Occur.FILTER);
actuallyRewritten = true;
} else if (query != rewritten) {
newQuery.add(rewritten, occur);
actuallyRewritten = true;
} else {
newQuery.add(clause);
}
}
if (actuallyRewritten == false) {
return this;
}
return newQuery.build();
@ -231,11 +234,7 @@ public class BooleanQuery extends Query implements Iterable<BooleanClause> {
@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
throws IOException {
BooleanQuery query = this;
if (scoreMode.needsScores() == false) {
query = rewriteNoScoring();
}
return new BooleanWeight(query, searcher, scoreMode, boost);
return new BooleanWeight(this, searcher, scoreMode, boost);
}
@Override
@ -274,12 +273,22 @@ public class BooleanQuery extends Query implements Iterable<BooleanClause> {
boolean actuallyRewritten = false;
for (BooleanClause clause : this) {
Query query = clause.getQuery();
Query rewritten = query.rewrite(reader);
BooleanClause.Occur occur = clause.getOccur();
Query rewritten;
if (occur == Occur.FILTER || occur == Occur.MUST_NOT) {
// Clauses that are not involved in scoring can get some extra simplifications
rewritten = new ConstantScoreQuery(query).rewrite(reader);
if (rewritten instanceof ConstantScoreQuery) {
rewritten = ((ConstantScoreQuery) rewritten).getQuery();
}
} else {
rewritten = query.rewrite(reader);
}
if (rewritten != query || query.getClass() == MatchNoDocsQuery.class) {
// rewrite clause
actuallyRewritten = true;
if (rewritten.getClass() == MatchNoDocsQuery.class) {
switch (clause.getOccur()) {
switch (occur) {
case SHOULD:
case MUST_NOT:
// the clause can be safely ignored
@ -289,7 +298,7 @@ public class BooleanQuery extends Query implements Iterable<BooleanClause> {
return rewritten;
}
} else {
builder.add(rewritten, clause.getOccur());
builder.add(rewritten, occur);
}
} else {
// leave as-is

View File

@ -411,6 +411,13 @@ final class BooleanWeight extends Weight {
return null;
}
if (scoreMode.needsScores() == false
&& minShouldMatch == 0
&& scorers.get(Occur.MUST).size() + scorers.get(Occur.FILTER).size() > 0) {
// Purely optional clauses are useless without scoring.
scorers.get(Occur.SHOULD).clear();
}
return new Boolean2ScorerSupplier(this, scorers, scoreMode, minShouldMatch);
}
}

View File

@ -43,6 +43,16 @@ public final class ConstantScoreQuery extends Query {
public Query rewrite(IndexReader reader) throws IOException {
Query rewritten = query.rewrite(reader);
// Do some extra simplifications that are legal since scores are not needed on the wrapped
// query.
if (rewritten instanceof BoostQuery) {
rewritten = ((BoostQuery) rewritten).getQuery();
} else if (rewritten instanceof ConstantScoreQuery) {
rewritten = ((ConstantScoreQuery) rewritten).getQuery();
} else if (rewritten instanceof BooleanQuery) {
rewritten = ((BooleanQuery) rewritten).rewriteNoScoring(reader);
}
if (rewritten.getClass() == MatchNoDocsQuery.class) {
// bubble up MatchNoDocsQuery
return rewritten;

View File

@ -446,7 +446,7 @@ public class IndexSearcher {
* possible.
*/
public int count(Query query) throws IOException {
query = rewrite(query);
query = rewrite(query, false);
final Weight weight = createWeight(query, ScoreMode.COMPLETE_NO_SCORES, 1);
final CollectorManager<ShortcutHitCountCollector, Integer> shortcutCollectorManager =
@ -551,7 +551,7 @@ public class IndexSearcher {
* clauses.
*/
public void search(Query query, Collector results) throws IOException {
query = rewrite(query);
query = rewrite(query, results.scoreMode().needsScores());
search(leafContexts, createWeight(query, results.scoreMode(), 1), results);
}
@ -682,7 +682,7 @@ public class IndexSearcher {
public <C extends Collector, T> T search(Query query, CollectorManager<C, T> collectorManager)
throws IOException {
final C firstCollector = collectorManager.newCollector();
query = rewrite(query);
query = rewrite(query, firstCollector.scoreMode().needsScores());
final Weight weight = createWeight(query, firstCollector.scoreMode(), 1);
return search(weight, collectorManager, firstCollector);
}
@ -795,6 +795,15 @@ public class IndexSearcher {
return query;
}
private Query rewrite(Query original, boolean needsScores) throws IOException {
if (needsScores) {
return rewrite(original);
} else {
// Take advantage of the few extra rewrite rules of ConstantScoreQuery.
return rewrite(new ConstantScoreQuery(original));
}
}
/**
* Returns a QueryVisitor which recursively checks the total number of clauses that a query and
* its children cumulatively have and validates that the total number does not exceed the

View File

@ -19,6 +19,7 @@ package org.apache.lucene.search;
import java.io.IOException;
import java.util.Arrays;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.lucene.document.Document;
@ -405,7 +406,7 @@ public class TestBooleanRewrites extends LuceneTestCase {
final int iters = atLeast(1000);
for (int i = 0; i < iters; ++i) {
Query query = randomQuery();
Query query = randomBooleanQuery(random());
final TopDocs td1 = searcher1.search(query, 100);
final TopDocs td2 = searcher2.search(query, 100);
assertEquals(td1, td2);
@ -415,29 +416,41 @@ public class TestBooleanRewrites extends LuceneTestCase {
dir.close();
}
private Query randomBooleanQuery() {
if (random().nextInt(10) == 0) {
return new BoostQuery(randomBooleanQuery(), TestUtil.nextInt(random(), 1, 10));
}
final int numClauses = random().nextInt(5);
private Query randomBooleanQuery(Random random) {
final int numClauses = random.nextInt(5);
BooleanQuery.Builder b = new BooleanQuery.Builder();
int numShoulds = 0;
for (int i = 0; i < numClauses; ++i) {
final Occur occur = Occur.values()[random().nextInt(Occur.values().length)];
final Occur occur = Occur.values()[random.nextInt(Occur.values().length)];
if (occur == Occur.SHOULD) {
numShoulds++;
}
final Query query = randomQuery();
final Query query = randomQuery(random);
b.add(query, occur);
}
b.setMinimumNumberShouldMatch(
random().nextBoolean() ? 0 : TestUtil.nextInt(random(), 0, numShoulds + 1));
return b.build();
Query query = b.build();
if (random.nextBoolean()) {
query = randomWrapper(random, query);
}
return query;
}
private Query randomQuery() {
if (random().nextInt(10) == 0) {
return new BoostQuery(randomBooleanQuery(), TestUtil.nextInt(random(), 1, 10));
private Query randomWrapper(Random random, Query query) {
switch (random.nextInt(2)) {
case 0:
return new BoostQuery(query, TestUtil.nextInt(random, 0, 4));
case 1:
return new ConstantScoreQuery(query);
default:
throw new AssertionError();
}
}
private Query randomQuery(Random random) {
if (random.nextInt(5) == 0) {
return randomWrapper(random, randomQuery(random));
}
switch (random().nextInt(6)) {
case 0:
@ -451,7 +464,7 @@ public class TestBooleanRewrites extends LuceneTestCase {
case 4:
return new TermQuery(new Term("body", "d"));
case 5:
return randomBooleanQuery();
return randomBooleanQuery(random);
default:
throw new AssertionError();
}
@ -609,59 +622,57 @@ public class TestBooleanRewrites extends LuceneTestCase {
}
public void testDiscardShouldClauses() throws IOException {
Directory dir = newDirectory();
RandomIndexWriter w = new RandomIndexWriter(random(), dir);
Document doc = new Document();
Field f = newTextField("field", "a", Field.Store.NO);
doc.add(f);
w.addDocument(doc);
w.commit();
IndexSearcher searcher = newSearcher(new MultiReader());
DirectoryReader reader = w.getReader();
final IndexSearcher searcher = new IndexSearcher(reader);
Query query1 =
new ConstantScoreQuery(
new BooleanQuery.Builder()
.add(new TermQuery(new Term("field", "a")), Occur.MUST)
.add(new TermQuery(new Term("field", "b")), Occur.SHOULD)
.build());
Query rewritten1 = new ConstantScoreQuery(new TermQuery(new Term("field", "a")));
assertEquals(rewritten1, searcher.rewrite(query1));
BooleanQuery.Builder query1 = new BooleanQuery.Builder();
query1.add(new TermQuery(new Term("field", "a")), Occur.MUST);
query1.add(new TermQuery(new Term("field", "b")), Occur.SHOULD);
Query query2 =
new ConstantScoreQuery(
new BooleanQuery.Builder()
.add(new TermQuery(new Term("field", "a")), Occur.MUST)
.add(new TermQuery(new Term("field", "b")), Occur.SHOULD)
.add(new TermQuery(new Term("field", "c")), Occur.FILTER)
.build());
Query rewritten2 =
new ConstantScoreQuery(
new BooleanQuery.Builder()
.add(new TermQuery(new Term("field", "a")), Occur.FILTER)
.add(new TermQuery(new Term("field", "c")), Occur.FILTER)
.build());
assertEquals(rewritten2, searcher.rewrite(query2));
query1.setMinimumNumberShouldMatch(0);
Query query3 =
new ConstantScoreQuery(
new BooleanQuery.Builder()
.add(new TermQuery(new Term("field", "a")), Occur.SHOULD)
.add(new TermQuery(new Term("field", "b")), Occur.SHOULD)
.build());
assertSame(query3, searcher.rewrite(query3));
Weight weight =
searcher.createWeight(searcher.rewrite(query1.build()), ScoreMode.COMPLETE_NO_SCORES, 1);
Query query4 =
new ConstantScoreQuery(
new BooleanQuery.Builder()
.add(new TermQuery(new Term("field", "a")), Occur.SHOULD)
.add(new TermQuery(new Term("field", "b")), Occur.MUST_NOT)
.build());
assertSame(query4, searcher.rewrite(query4));
Query rewrittenQuery1 = weight.getQuery();
assertTrue(rewrittenQuery1 instanceof BooleanQuery);
BooleanQuery booleanRewrittenQuery1 = (BooleanQuery) rewrittenQuery1;
for (BooleanClause clause : booleanRewrittenQuery1.clauses()) {
assertNotEquals(clause.getOccur(), Occur.SHOULD);
}
BooleanQuery.Builder query2 = new BooleanQuery.Builder();
query2.add(new TermQuery(new Term("field", "a")), Occur.MUST);
query2.add(new TermQuery(new Term("field", "b")), Occur.SHOULD);
query2.add(new TermQuery(new Term("field", "c")), Occur.FILTER);
query2.setMinimumNumberShouldMatch(0);
weight =
searcher.createWeight(searcher.rewrite(query2.build()), ScoreMode.COMPLETE_NO_SCORES, 1);
Query rewrittenQuery2 = weight.getQuery();
assertTrue(rewrittenQuery2 instanceof BooleanQuery);
BooleanQuery booleanRewrittenQuery2 = (BooleanQuery) rewrittenQuery1;
for (BooleanClause clause : booleanRewrittenQuery2.clauses()) {
assertNotEquals(clause.getOccur(), Occur.SHOULD);
}
reader.close();
w.close();
dir.close();
Query query5 =
new ConstantScoreQuery(
new BooleanQuery.Builder()
.setMinimumNumberShouldMatch(1)
.add(new TermQuery(new Term("field", "a")), Occur.SHOULD)
.add(new TermQuery(new Term("field", "b")), Occur.SHOULD)
.add(new TermQuery(new Term("field", "c")), Occur.FILTER)
.build());
assertSame(query5, searcher.rewrite(query5));
}
public void testShouldMatchNoDocsQuery() throws IOException {
@ -713,4 +724,63 @@ public class TestBooleanRewrites extends LuceneTestCase {
BooleanQuery query = new BooleanQuery.Builder().build();
assertEquals(new MatchNoDocsQuery(), searcher.rewrite(query));
}
public void testSimplifyFilterClauses() throws IOException {
IndexSearcher searcher = newSearcher(new MultiReader());
BooleanQuery query1 =
new BooleanQuery.Builder()
.add(new TermQuery(new Term("foo", "bar")), Occur.MUST)
.add(new ConstantScoreQuery(new TermQuery(new Term("foo", "baz"))), Occur.FILTER)
.build();
BooleanQuery expected1 =
new BooleanQuery.Builder()
.add(new TermQuery(new Term("foo", "bar")), Occur.MUST)
.add(new TermQuery(new Term("foo", "baz")), Occur.FILTER)
.build();
assertEquals(expected1, searcher.rewrite(query1));
BooleanQuery query2 =
new BooleanQuery.Builder()
.add(new TermQuery(new Term("foo", "bar")), Occur.FILTER)
.add(new ConstantScoreQuery(new TermQuery(new Term("foo", "bar"))), Occur.FILTER)
.build();
Query expected2 =
new BoostQuery(new ConstantScoreQuery(new TermQuery(new Term("foo", "bar"))), 0);
assertEquals(expected2, searcher.rewrite(query2));
}
public void testSimplifyMustNotClauses() throws IOException {
IndexSearcher searcher = newSearcher(new MultiReader());
BooleanQuery query =
new BooleanQuery.Builder()
.add(new TermQuery(new Term("foo", "bar")), Occur.MUST)
.add(new ConstantScoreQuery(new TermQuery(new Term("foo", "baz"))), Occur.MUST_NOT)
.build();
BooleanQuery expected =
new BooleanQuery.Builder()
.add(new TermQuery(new Term("foo", "bar")), Occur.MUST)
.add(new TermQuery(new Term("foo", "baz")), Occur.MUST_NOT)
.build();
assertEquals(expected, searcher.rewrite(query));
}
public void testSimplifyNonScoringShouldClauses() throws IOException {
IndexSearcher searcher = newSearcher(new MultiReader());
Query query =
new ConstantScoreQuery(
new BooleanQuery.Builder()
.add(new TermQuery(new Term("foo", "bar")), Occur.SHOULD)
.add(new ConstantScoreQuery(new TermQuery(new Term("foo", "baz"))), Occur.SHOULD)
.build());
Query expected =
new ConstantScoreQuery(
new BooleanQuery.Builder()
.add(new TermQuery(new Term("foo", "bar")), Occur.SHOULD)
.add(new TermQuery(new Term("foo", "baz")), Occur.SHOULD)
.build());
assertEquals(expected, searcher.rewrite(query));
}
}

View File

@ -91,5 +91,8 @@ public class TestBoostQuery extends LuceneTestCase {
Query query = new BoostQuery(new MatchNoDocsQuery(), 2f);
assertEquals(new MatchNoDocsQuery(), searcher.rewrite(query));
query = new BoostQuery(new MatchNoDocsQuery(), 0f);
assertEquals(new MatchNoDocsQuery(), searcher.rewrite(query));
}
}

View File

@ -1971,7 +1971,19 @@ public class TestLRUQueryCache extends LuceneTestCase {
w.addDocuments(Arrays.asList(doc1, doc2, doc3));
final IndexReader reader = w.getReader();
final IndexSearcher searcher = newSearcher(reader);
final UsageTrackingQueryCachingPolicy policy = new UsageTrackingQueryCachingPolicy();
final QueryCachingPolicy policy =
new QueryCachingPolicy() {
@Override
public boolean shouldCache(Query query) throws IOException {
return query.getClass() != TermQuery.class;
}
@Override
public void onUse(Query query) {
// no-op
}
};
searcher.setQueryCachingPolicy(policy);
w.close();

View File

@ -46,6 +46,9 @@ public class TestNeedsScores extends LuceneTestCase {
}
reader = iw.getReader();
searcher = newSearcher(reader);
// Needed so that the cache doesn't consume weights with ScoreMode.COMPLETE_NO_SCORES for the
// purpose of populating the cache.
searcher.setQueryCache(null);
iw.close();
}