diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 9ff8cdd2664..f424ec19cfd 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -35,7 +35,9 @@ Improvements Optimizations --------------------- -(No changes) + +* LUCENE-10418: More `Query#rewrite` optimizations for the non-scoring case. + (Adrien Grand) Bug Fixes --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/search/BooleanQuery.java b/lucene/core/src/java/org/apache/lucene/search/BooleanQuery.java index 1b066d7ad5e..165bf12de9f 100644 --- a/lucene/core/src/java/org/apache/lucene/search/BooleanQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/BooleanQuery.java @@ -191,38 +191,41 @@ public class BooleanQuery extends Query implements Iterable { return clauses.iterator(); } - private BooleanQuery rewriteNoScoring() { - boolean keepShould = + // Utility method for rewriting BooleanQuery when scores are not needed. + // This is called from ConstantScoreQuery#rewrite + BooleanQuery rewriteNoScoring(IndexReader reader) throws IOException { + boolean actuallyRewritten = false; + BooleanQuery.Builder newQuery = + new BooleanQuery.Builder().setMinimumNumberShouldMatch(getMinimumNumberShouldMatch()); + + final boolean keepShould = getMinimumNumberShouldMatch() > 0 || (clauseSets.get(Occur.MUST).size() + clauseSets.get(Occur.FILTER).size() == 0); - if (clauseSets.get(Occur.MUST).size() == 0 && keepShould) { - return this; - } - BooleanQuery.Builder newQuery = new BooleanQuery.Builder(); - - newQuery.setMinimumNumberShouldMatch(getMinimumNumberShouldMatch()); for (BooleanClause clause : clauses) { - switch (clause.getOccur()) { - case MUST: - { - newQuery.add(clause.getQuery(), Occur.FILTER); - break; - } - case SHOULD: - { - if (keepShould) { - newQuery.add(clause); - } - break; - } - case FILTER: - case MUST_NOT: - default: - { - newQuery.add(clause); - } + Query query = clause.getQuery(); + Query rewritten = new ConstantScoreQuery(query).rewrite(reader); + if (rewritten instanceof ConstantScoreQuery) { + rewritten = ((ConstantScoreQuery) rewritten).getQuery(); } + BooleanClause.Occur occur = clause.getOccur(); + if (occur == Occur.SHOULD && keepShould == false) { + // ignore clause + actuallyRewritten = true; + } else if (occur == Occur.MUST) { + // replace MUST clauses with FILTER clauses + newQuery.add(rewritten, Occur.FILTER); + actuallyRewritten = true; + } else if (query != rewritten) { + newQuery.add(rewritten, occur); + actuallyRewritten = true; + } else { + newQuery.add(clause); + } + } + + if (actuallyRewritten == false) { + return this; } return newQuery.build(); @@ -231,11 +234,7 @@ public class BooleanQuery extends Query implements Iterable { @Override public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { - BooleanQuery query = this; - if (scoreMode.needsScores() == false) { - query = rewriteNoScoring(); - } - return new BooleanWeight(query, searcher, scoreMode, boost); + return new BooleanWeight(this, searcher, scoreMode, boost); } @Override @@ -274,12 +273,22 @@ public class BooleanQuery extends Query implements Iterable { boolean actuallyRewritten = false; for (BooleanClause clause : this) { Query query = clause.getQuery(); - Query rewritten = query.rewrite(reader); + BooleanClause.Occur occur = clause.getOccur(); + Query rewritten; + if (occur == Occur.FILTER || occur == Occur.MUST_NOT) { + // Clauses that are not involved in scoring can get some extra simplifications + rewritten = new ConstantScoreQuery(query).rewrite(reader); + if (rewritten instanceof ConstantScoreQuery) { + rewritten = ((ConstantScoreQuery) rewritten).getQuery(); + } + } else { + rewritten = query.rewrite(reader); + } if (rewritten != query || query.getClass() == MatchNoDocsQuery.class) { // rewrite clause actuallyRewritten = true; if (rewritten.getClass() == MatchNoDocsQuery.class) { - switch (clause.getOccur()) { + switch (occur) { case SHOULD: case MUST_NOT: // the clause can be safely ignored @@ -289,7 +298,7 @@ public class BooleanQuery extends Query implements Iterable { return rewritten; } } else { - builder.add(rewritten, clause.getOccur()); + builder.add(rewritten, occur); } } else { // leave as-is diff --git a/lucene/core/src/java/org/apache/lucene/search/BooleanWeight.java b/lucene/core/src/java/org/apache/lucene/search/BooleanWeight.java index dffdbf7f6aa..2fb9c4515b6 100644 --- a/lucene/core/src/java/org/apache/lucene/search/BooleanWeight.java +++ b/lucene/core/src/java/org/apache/lucene/search/BooleanWeight.java @@ -411,6 +411,13 @@ final class BooleanWeight extends Weight { return null; } + if (scoreMode.needsScores() == false + && minShouldMatch == 0 + && scorers.get(Occur.MUST).size() + scorers.get(Occur.FILTER).size() > 0) { + // Purely optional clauses are useless without scoring. + scorers.get(Occur.SHOULD).clear(); + } + return new Boolean2ScorerSupplier(this, scorers, scoreMode, minShouldMatch); } } diff --git a/lucene/core/src/java/org/apache/lucene/search/ConstantScoreQuery.java b/lucene/core/src/java/org/apache/lucene/search/ConstantScoreQuery.java index 2f75a568793..c8a0b2ecc37 100644 --- a/lucene/core/src/java/org/apache/lucene/search/ConstantScoreQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/ConstantScoreQuery.java @@ -43,6 +43,16 @@ public final class ConstantScoreQuery extends Query { public Query rewrite(IndexReader reader) throws IOException { Query rewritten = query.rewrite(reader); + // Do some extra simplifications that are legal since scores are not needed on the wrapped + // query. + if (rewritten instanceof BoostQuery) { + rewritten = ((BoostQuery) rewritten).getQuery(); + } else if (rewritten instanceof ConstantScoreQuery) { + rewritten = ((ConstantScoreQuery) rewritten).getQuery(); + } else if (rewritten instanceof BooleanQuery) { + rewritten = ((BooleanQuery) rewritten).rewriteNoScoring(reader); + } + if (rewritten.getClass() == MatchNoDocsQuery.class) { // bubble up MatchNoDocsQuery return rewritten; diff --git a/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java b/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java index 5d47bbe7a99..4e64089f3c1 100644 --- a/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java @@ -446,7 +446,7 @@ public class IndexSearcher { * possible. */ public int count(Query query) throws IOException { - query = rewrite(query); + query = rewrite(query, false); final Weight weight = createWeight(query, ScoreMode.COMPLETE_NO_SCORES, 1); final CollectorManager shortcutCollectorManager = @@ -551,7 +551,7 @@ public class IndexSearcher { * clauses. */ public void search(Query query, Collector results) throws IOException { - query = rewrite(query); + query = rewrite(query, results.scoreMode().needsScores()); search(leafContexts, createWeight(query, results.scoreMode(), 1), results); } @@ -682,7 +682,7 @@ public class IndexSearcher { public T search(Query query, CollectorManager collectorManager) throws IOException { final C firstCollector = collectorManager.newCollector(); - query = rewrite(query); + query = rewrite(query, firstCollector.scoreMode().needsScores()); final Weight weight = createWeight(query, firstCollector.scoreMode(), 1); return search(weight, collectorManager, firstCollector); } @@ -795,6 +795,15 @@ public class IndexSearcher { return query; } + private Query rewrite(Query original, boolean needsScores) throws IOException { + if (needsScores) { + return rewrite(original); + } else { + // Take advantage of the few extra rewrite rules of ConstantScoreQuery. + return rewrite(new ConstantScoreQuery(original)); + } + } + /** * Returns a QueryVisitor which recursively checks the total number of clauses that a query and * its children cumulatively have and validates that the total number does not exceed the diff --git a/lucene/core/src/test/org/apache/lucene/search/TestBooleanRewrites.java b/lucene/core/src/test/org/apache/lucene/search/TestBooleanRewrites.java index 93da1ba69c6..8ffb04bced3 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestBooleanRewrites.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestBooleanRewrites.java @@ -19,6 +19,7 @@ package org.apache.lucene.search; import java.io.IOException; import java.util.Arrays; import java.util.Map; +import java.util.Random; import java.util.Set; import java.util.stream.Collectors; import org.apache.lucene.document.Document; @@ -405,7 +406,7 @@ public class TestBooleanRewrites extends LuceneTestCase { final int iters = atLeast(1000); for (int i = 0; i < iters; ++i) { - Query query = randomQuery(); + Query query = randomBooleanQuery(random()); final TopDocs td1 = searcher1.search(query, 100); final TopDocs td2 = searcher2.search(query, 100); assertEquals(td1, td2); @@ -415,29 +416,41 @@ public class TestBooleanRewrites extends LuceneTestCase { dir.close(); } - private Query randomBooleanQuery() { - if (random().nextInt(10) == 0) { - return new BoostQuery(randomBooleanQuery(), TestUtil.nextInt(random(), 1, 10)); - } - final int numClauses = random().nextInt(5); + private Query randomBooleanQuery(Random random) { + final int numClauses = random.nextInt(5); BooleanQuery.Builder b = new BooleanQuery.Builder(); int numShoulds = 0; for (int i = 0; i < numClauses; ++i) { - final Occur occur = Occur.values()[random().nextInt(Occur.values().length)]; + final Occur occur = Occur.values()[random.nextInt(Occur.values().length)]; if (occur == Occur.SHOULD) { numShoulds++; } - final Query query = randomQuery(); + final Query query = randomQuery(random); b.add(query, occur); } b.setMinimumNumberShouldMatch( random().nextBoolean() ? 0 : TestUtil.nextInt(random(), 0, numShoulds + 1)); - return b.build(); + Query query = b.build(); + if (random.nextBoolean()) { + query = randomWrapper(random, query); + } + return query; } - private Query randomQuery() { - if (random().nextInt(10) == 0) { - return new BoostQuery(randomBooleanQuery(), TestUtil.nextInt(random(), 1, 10)); + private Query randomWrapper(Random random, Query query) { + switch (random.nextInt(2)) { + case 0: + return new BoostQuery(query, TestUtil.nextInt(random, 0, 4)); + case 1: + return new ConstantScoreQuery(query); + default: + throw new AssertionError(); + } + } + + private Query randomQuery(Random random) { + if (random.nextInt(5) == 0) { + return randomWrapper(random, randomQuery(random)); } switch (random().nextInt(6)) { case 0: @@ -451,7 +464,7 @@ public class TestBooleanRewrites extends LuceneTestCase { case 4: return new TermQuery(new Term("body", "d")); case 5: - return randomBooleanQuery(); + return randomBooleanQuery(random); default: throw new AssertionError(); } @@ -609,59 +622,57 @@ public class TestBooleanRewrites extends LuceneTestCase { } public void testDiscardShouldClauses() throws IOException { - Directory dir = newDirectory(); - RandomIndexWriter w = new RandomIndexWriter(random(), dir); - Document doc = new Document(); - Field f = newTextField("field", "a", Field.Store.NO); - doc.add(f); - w.addDocument(doc); - w.commit(); + IndexSearcher searcher = newSearcher(new MultiReader()); - DirectoryReader reader = w.getReader(); - final IndexSearcher searcher = new IndexSearcher(reader); + Query query1 = + new ConstantScoreQuery( + new BooleanQuery.Builder() + .add(new TermQuery(new Term("field", "a")), Occur.MUST) + .add(new TermQuery(new Term("field", "b")), Occur.SHOULD) + .build()); + Query rewritten1 = new ConstantScoreQuery(new TermQuery(new Term("field", "a"))); + assertEquals(rewritten1, searcher.rewrite(query1)); - BooleanQuery.Builder query1 = new BooleanQuery.Builder(); - query1.add(new TermQuery(new Term("field", "a")), Occur.MUST); - query1.add(new TermQuery(new Term("field", "b")), Occur.SHOULD); + Query query2 = + new ConstantScoreQuery( + new BooleanQuery.Builder() + .add(new TermQuery(new Term("field", "a")), Occur.MUST) + .add(new TermQuery(new Term("field", "b")), Occur.SHOULD) + .add(new TermQuery(new Term("field", "c")), Occur.FILTER) + .build()); + Query rewritten2 = + new ConstantScoreQuery( + new BooleanQuery.Builder() + .add(new TermQuery(new Term("field", "a")), Occur.FILTER) + .add(new TermQuery(new Term("field", "c")), Occur.FILTER) + .build()); + assertEquals(rewritten2, searcher.rewrite(query2)); - query1.setMinimumNumberShouldMatch(0); + Query query3 = + new ConstantScoreQuery( + new BooleanQuery.Builder() + .add(new TermQuery(new Term("field", "a")), Occur.SHOULD) + .add(new TermQuery(new Term("field", "b")), Occur.SHOULD) + .build()); + assertSame(query3, searcher.rewrite(query3)); - Weight weight = - searcher.createWeight(searcher.rewrite(query1.build()), ScoreMode.COMPLETE_NO_SCORES, 1); + Query query4 = + new ConstantScoreQuery( + new BooleanQuery.Builder() + .add(new TermQuery(new Term("field", "a")), Occur.SHOULD) + .add(new TermQuery(new Term("field", "b")), Occur.MUST_NOT) + .build()); + assertSame(query4, searcher.rewrite(query4)); - Query rewrittenQuery1 = weight.getQuery(); - - assertTrue(rewrittenQuery1 instanceof BooleanQuery); - - BooleanQuery booleanRewrittenQuery1 = (BooleanQuery) rewrittenQuery1; - - for (BooleanClause clause : booleanRewrittenQuery1.clauses()) { - assertNotEquals(clause.getOccur(), Occur.SHOULD); - } - - BooleanQuery.Builder query2 = new BooleanQuery.Builder(); - query2.add(new TermQuery(new Term("field", "a")), Occur.MUST); - query2.add(new TermQuery(new Term("field", "b")), Occur.SHOULD); - query2.add(new TermQuery(new Term("field", "c")), Occur.FILTER); - - query2.setMinimumNumberShouldMatch(0); - - weight = - searcher.createWeight(searcher.rewrite(query2.build()), ScoreMode.COMPLETE_NO_SCORES, 1); - - Query rewrittenQuery2 = weight.getQuery(); - - assertTrue(rewrittenQuery2 instanceof BooleanQuery); - - BooleanQuery booleanRewrittenQuery2 = (BooleanQuery) rewrittenQuery1; - - for (BooleanClause clause : booleanRewrittenQuery2.clauses()) { - assertNotEquals(clause.getOccur(), Occur.SHOULD); - } - - reader.close(); - w.close(); - dir.close(); + Query query5 = + new ConstantScoreQuery( + new BooleanQuery.Builder() + .setMinimumNumberShouldMatch(1) + .add(new TermQuery(new Term("field", "a")), Occur.SHOULD) + .add(new TermQuery(new Term("field", "b")), Occur.SHOULD) + .add(new TermQuery(new Term("field", "c")), Occur.FILTER) + .build()); + assertSame(query5, searcher.rewrite(query5)); } public void testShouldMatchNoDocsQuery() throws IOException { @@ -713,4 +724,63 @@ public class TestBooleanRewrites extends LuceneTestCase { BooleanQuery query = new BooleanQuery.Builder().build(); assertEquals(new MatchNoDocsQuery(), searcher.rewrite(query)); } + + public void testSimplifyFilterClauses() throws IOException { + IndexSearcher searcher = newSearcher(new MultiReader()); + + BooleanQuery query1 = + new BooleanQuery.Builder() + .add(new TermQuery(new Term("foo", "bar")), Occur.MUST) + .add(new ConstantScoreQuery(new TermQuery(new Term("foo", "baz"))), Occur.FILTER) + .build(); + BooleanQuery expected1 = + new BooleanQuery.Builder() + .add(new TermQuery(new Term("foo", "bar")), Occur.MUST) + .add(new TermQuery(new Term("foo", "baz")), Occur.FILTER) + .build(); + assertEquals(expected1, searcher.rewrite(query1)); + + BooleanQuery query2 = + new BooleanQuery.Builder() + .add(new TermQuery(new Term("foo", "bar")), Occur.FILTER) + .add(new ConstantScoreQuery(new TermQuery(new Term("foo", "bar"))), Occur.FILTER) + .build(); + Query expected2 = + new BoostQuery(new ConstantScoreQuery(new TermQuery(new Term("foo", "bar"))), 0); + assertEquals(expected2, searcher.rewrite(query2)); + } + + public void testSimplifyMustNotClauses() throws IOException { + IndexSearcher searcher = newSearcher(new MultiReader()); + + BooleanQuery query = + new BooleanQuery.Builder() + .add(new TermQuery(new Term("foo", "bar")), Occur.MUST) + .add(new ConstantScoreQuery(new TermQuery(new Term("foo", "baz"))), Occur.MUST_NOT) + .build(); + BooleanQuery expected = + new BooleanQuery.Builder() + .add(new TermQuery(new Term("foo", "bar")), Occur.MUST) + .add(new TermQuery(new Term("foo", "baz")), Occur.MUST_NOT) + .build(); + assertEquals(expected, searcher.rewrite(query)); + } + + public void testSimplifyNonScoringShouldClauses() throws IOException { + IndexSearcher searcher = newSearcher(new MultiReader()); + + Query query = + new ConstantScoreQuery( + new BooleanQuery.Builder() + .add(new TermQuery(new Term("foo", "bar")), Occur.SHOULD) + .add(new ConstantScoreQuery(new TermQuery(new Term("foo", "baz"))), Occur.SHOULD) + .build()); + Query expected = + new ConstantScoreQuery( + new BooleanQuery.Builder() + .add(new TermQuery(new Term("foo", "bar")), Occur.SHOULD) + .add(new TermQuery(new Term("foo", "baz")), Occur.SHOULD) + .build()); + assertEquals(expected, searcher.rewrite(query)); + } } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestBoostQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestBoostQuery.java index 79e3a58939a..f4f02d2969c 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestBoostQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestBoostQuery.java @@ -91,5 +91,8 @@ public class TestBoostQuery extends LuceneTestCase { Query query = new BoostQuery(new MatchNoDocsQuery(), 2f); assertEquals(new MatchNoDocsQuery(), searcher.rewrite(query)); + + query = new BoostQuery(new MatchNoDocsQuery(), 0f); + assertEquals(new MatchNoDocsQuery(), searcher.rewrite(query)); } } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestLRUQueryCache.java b/lucene/core/src/test/org/apache/lucene/search/TestLRUQueryCache.java index 2b7bf02a2a2..d1db1551b48 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestLRUQueryCache.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestLRUQueryCache.java @@ -1971,7 +1971,19 @@ public class TestLRUQueryCache extends LuceneTestCase { w.addDocuments(Arrays.asList(doc1, doc2, doc3)); final IndexReader reader = w.getReader(); final IndexSearcher searcher = newSearcher(reader); - final UsageTrackingQueryCachingPolicy policy = new UsageTrackingQueryCachingPolicy(); + final QueryCachingPolicy policy = + new QueryCachingPolicy() { + + @Override + public boolean shouldCache(Query query) throws IOException { + return query.getClass() != TermQuery.class; + } + + @Override + public void onUse(Query query) { + // no-op + } + }; searcher.setQueryCachingPolicy(policy); w.close(); diff --git a/lucene/core/src/test/org/apache/lucene/search/TestNeedsScores.java b/lucene/core/src/test/org/apache/lucene/search/TestNeedsScores.java index c4c6b06698a..aa2f3c89212 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestNeedsScores.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestNeedsScores.java @@ -46,6 +46,9 @@ public class TestNeedsScores extends LuceneTestCase { } reader = iw.getReader(); searcher = newSearcher(reader); + // Needed so that the cache doesn't consume weights with ScoreMode.COMPLETE_NO_SCORES for the + // purpose of populating the cache. + searcher.setQueryCache(null); iw.close(); }