From 44e9f5de53aefb59f8aa3f95eb348ae79280f390 Mon Sep 17 00:00:00 2001 From: Gautam Worah Date: Fri, 3 Sep 2021 00:09:38 -0700 Subject: [PATCH] LUCENE-9620 Add Weight#count(LeafReaderContext) (#242) Add a default implementation in Weight.java and add sample faster implementations in MatchAllDocsQuery, MatchNoDocsQuery, TermQuery Add tests for BooleanQuery and TermQuery Co-authored-by: Gautam Worah Co-authored-by: Adrien Grand --- lucene/CHANGES.txt | 4 + .../lucene/search/ConstantScoreQuery.java | 5 + .../apache/lucene/search/IndexSearcher.java | 100 ++++++++++-------- .../lucene/search/MatchAllDocsQuery.java | 5 + .../lucene/search/MatchNoDocsQuery.java | 5 + .../org/apache/lucene/search/TermQuery.java | 16 +++ .../java/org/apache/lucene/search/Weight.java | 21 ++++ .../lucene/search/TestBooleanQuery.java | 40 +++++++ .../lucene/search/TestFilterWeight.java | 3 +- .../apache/lucene/search/TestTermQuery.java | 29 +++++ 10 files changed, 181 insertions(+), 47 deletions(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index dca1b1be8d5..7ac563b3e53 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -143,6 +143,10 @@ API Changes * LUCENE-10027: Directory reader open API from indexCommit and leafSorter has been modified to add an extra parameter - minSupportedMajorVersion. (Mayya Sharipova) +* LUCENE-9620: Added a (sometimes) faster implementation for IndexSearcher#count that relies on the new Weight#count API. + The Weight#count API represents a cleaner way for Query classes to optimize their counting method. + (Gautam Worah, Adrien Grand) + Improvements * LUCENE-9960: Avoid unnecessary top element replacement for equal elements in PriorityQueue. 
(Dawid Weiss) diff --git a/lucene/core/src/java/org/apache/lucene/search/ConstantScoreQuery.java b/lucene/core/src/java/org/apache/lucene/search/ConstantScoreQuery.java index 015c82108d9..4be41494092 100644 --- a/lucene/core/src/java/org/apache/lucene/search/ConstantScoreQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/ConstantScoreQuery.java @@ -168,6 +168,11 @@ public final class ConstantScoreQuery extends Query { public boolean isCacheable(LeafReaderContext ctx) { return innerWeight.isCacheable(ctx); } + + @Override + public int count(LeafReaderContext context) throws IOException { + return innerWeight.count(context); + } }; } else { return innerWeight; diff --git a/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java b/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java index d7937a5cba9..d771cd9ec57 100644 --- a/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java @@ -402,49 +402,57 @@ public class IndexSearcher { return similarity; } + private static class ShortcutHitCountCollector implements Collector { + private final Weight weight; + private final TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector(); + private int weightCount; + + ShortcutHitCountCollector(Weight weight) { + this.weight = weight; + } + + @Override + public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException { + int count = weight.count(context); + // check if the number of hits can be computed in constant time + if (count == -1) { + // use a TotalHitCountCollector to calculate the number of hits in the usual way + return totalHitCountCollector.getLeafCollector(context); + } else { + weightCount += count; + throw new CollectionTerminatedException(); + } + } + + @Override + public ScoreMode scoreMode() { + return ScoreMode.COMPLETE_NO_SCORES; + } + } + /** Count how many documents match the given query. 
*/ public int count(Query query) throws IOException { query = rewrite(query); - while (true) { - // remove wrappers that don't matter for counts - if (query instanceof ConstantScoreQuery) { - query = ((ConstantScoreQuery) query).getQuery(); - } else { - break; - } - } - - // some counts can be computed in constant time - if (query instanceof MatchAllDocsQuery) { - return reader.numDocs(); - } else if (query instanceof TermQuery && reader.hasDeletions() == false) { - Term term = ((TermQuery) query).getTerm(); - int count = 0; - for (LeafReaderContext leaf : reader.leaves()) { - count += leaf.reader().docFreq(term); - } - return count; - } - - // general case: create a collector and count matches - final CollectorManager collectorManager = - new CollectorManager() { + final Weight weight = createWeight(query, ScoreMode.COMPLETE_NO_SCORES, 1); + final CollectorManager shortcutCollectorManager = + new CollectorManager() { @Override - public TotalHitCountCollector newCollector() throws IOException { - return new TotalHitCountCollector(); + public ShortcutHitCountCollector newCollector() throws IOException { + return new ShortcutHitCountCollector(weight); } @Override - public Integer reduce(Collection collectors) throws IOException { - int total = 0; - for (TotalHitCountCollector collector : collectors) { - total += collector.getTotalHits(); + public Integer reduce(Collection collectors) + throws IOException { + int totalHitCount = 0; + for (ShortcutHitCountCollector c : collectors) { + totalHitCount += c.weightCount + c.totalHitCountCollector.getTotalHits(); } - return total; + return totalHitCount; } }; - return search(query, collectorManager); + return search(weight, shortcutCollectorManager, new ShortcutHitCountCollector(weight)); } /** @@ -659,29 +667,29 @@ public class IndexSearcher { */ public T search(Query query, CollectorManager collectorManager) throws IOException { + final C firstCollector = collectorManager.newCollector(); + query = rewrite(query); + final 
Weight weight = createWeight(query, firstCollector.scoreMode(), 1); + return search(weight, collectorManager, firstCollector); + } + + private T search( + Weight weight, CollectorManager collectorManager, C firstCollector) throws IOException { if (executor == null || leafSlices.length <= 1) { - final C collector = collectorManager.newCollector(); - search(query, collector); - return collectorManager.reduce(Collections.singletonList(collector)); + search(leafContexts, weight, firstCollector); + return collectorManager.reduce(Collections.singletonList(firstCollector)); } else { final List collectors = new ArrayList<>(leafSlices.length); - ScoreMode scoreMode = null; - for (int i = 0; i < leafSlices.length; ++i) { + collectors.add(firstCollector); + final ScoreMode scoreMode = firstCollector.scoreMode(); + for (int i = 1; i < leafSlices.length; ++i) { final C collector = collectorManager.newCollector(); collectors.add(collector); - if (scoreMode == null) { - scoreMode = collector.scoreMode(); - } else if (scoreMode != collector.scoreMode()) { + if (scoreMode != collector.scoreMode()) { throw new IllegalStateException( "CollectorManager does not always produce collectors with the same score mode"); } } - if (scoreMode == null) { - // no segments - scoreMode = ScoreMode.COMPLETE; - } - query = rewrite(query); - final Weight weight = createWeight(query, scoreMode, 1); final List> listTasks = new ArrayList<>(); for (int i = 0; i < leafSlices.length; ++i) { final LeafReaderContext[] leaves = leafSlices[i].leaves; diff --git a/lucene/core/src/java/org/apache/lucene/search/MatchAllDocsQuery.java b/lucene/core/src/java/org/apache/lucene/search/MatchAllDocsQuery.java index 5f6baad1a1d..e13521cc8d4 100644 --- a/lucene/core/src/java/org/apache/lucene/search/MatchAllDocsQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/MatchAllDocsQuery.java @@ -72,6 +72,11 @@ public final class MatchAllDocsQuery extends Query { } }; } + + @Override + public int 
count(LeafReaderContext context) { + return context.reader().numDocs(); + } }; } diff --git a/lucene/core/src/java/org/apache/lucene/search/MatchNoDocsQuery.java b/lucene/core/src/java/org/apache/lucene/search/MatchNoDocsQuery.java index dd731680297..b05697791a8 100644 --- a/lucene/core/src/java/org/apache/lucene/search/MatchNoDocsQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/MatchNoDocsQuery.java @@ -52,6 +52,11 @@ public class MatchNoDocsQuery extends Query { public boolean isCacheable(LeafReaderContext ctx) { return true; } + + @Override + public int count(LeafReaderContext context) { + return 0; + } }; } diff --git a/lucene/core/src/java/org/apache/lucene/search/TermQuery.java b/lucene/core/src/java/org/apache/lucene/search/TermQuery.java index 3f52eb3573d..f08855abe5c 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TermQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/TermQuery.java @@ -179,6 +179,22 @@ public class TermQuery extends Query { } return Explanation.noMatch("no matching term"); } + + @Override + public int count(LeafReaderContext context) throws IOException { + if (context.reader().hasDeletions() == false) { + TermsEnum termsEnum = getTermsEnum(context); + // termsEnum is not null if term state is available + if (termsEnum != null) { + return termsEnum.docFreq(); + } else { + // the term cannot be found in the dictionary so the count is 0 + return 0; + } + } else { + return super.count(context); + } + } } /** Constructs a query for the term t. 
*/ diff --git a/lucene/core/src/java/org/apache/lucene/search/Weight.java b/lucene/core/src/java/org/apache/lucene/search/Weight.java index df1fa34a748..297543580ea 100644 --- a/lucene/core/src/java/org/apache/lucene/search/Weight.java +++ b/lucene/core/src/java/org/apache/lucene/search/Weight.java @@ -174,6 +174,27 @@ public abstract class Weight implements SegmentCacheable { return new DefaultBulkScorer(scorer); } + /** + * Counts the number of live documents that match a given {@link Weight#parentQuery} in a leaf. + * + *

<p>The default implementation returns -1 for every query. This indicates that the count could + * not be computed in O(1) time. + * + *

<p>Specific query classes should override it to provide other accurate O(1) implementations + * (that actually return the count). Look at {@link MatchAllDocsQuery#createWeight(IndexSearcher, + * ScoreMode, float)} for an example. + * + *

<p>We use this property of the function to count hits in {@link IndexSearcher#count(Query)}. + * + * @param context the {@link org.apache.lucene.index.LeafReaderContext} for which to return the + * count. + * @return integer count of the number of matches + * @throws IOException if there is a low-level I/O error + */ + public int count(LeafReaderContext context) throws IOException { + return -1; + } + /** * Just wraps a Scorer and performs top scoring using it. * diff --git a/lucene/core/src/test/org/apache/lucene/search/TestBooleanQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestBooleanQuery.java index 282f68f150c..a24fac012e6 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestBooleanQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestBooleanQuery.java @@ -41,6 +41,7 @@ import org.apache.lucene.index.Term; import org.apache.lucene.search.BooleanClause.Occur; import org.apache.lucene.search.similarities.ClassicSimilarity; import org.apache.lucene.store.Directory; +import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.LuceneTestCase; import org.apache.lucene.util.NamedThreadFactory; import org.apache.lucene.util.TestUtil; @@ -736,6 +737,45 @@ public class TestBooleanQuery extends LuceneTestCase { dir.close(); } + // LUCENE-9620 Add Weight#count(LeafReaderContext) + public void testQueryMatchesCount() throws IOException { + Directory dir = newDirectory(); + RandomIndexWriter w = new RandomIndexWriter(random(), dir); + + int randomNumDocs = random().nextInt(500); + int numMatchingDocs = 0; + + for (int i = 0; i < randomNumDocs; i++) { + Document doc = new Document(); + Field f; + if (random().nextBoolean()) { + f = newTextField("field", "a b c " + random().nextInt(), Field.Store.NO); + numMatchingDocs++; + } else { + f = newTextField("field", String.valueOf(random().nextInt()), Field.Store.NO); + } + doc.add(f); + w.addDocument(doc); + } + w.commit(); + + DirectoryReader reader = w.getReader(); + final
IndexSearcher searcher = new IndexSearcher(reader); + + BooleanQuery.Builder q = new BooleanQuery.Builder(); + q.add(new PhraseQuery("field", "a", "b"), Occur.SHOULD); + q.add(new TermQuery(new Term("field", "c")), Occur.SHOULD); + + Query builtQuery = q.build(); + + assertEquals(searcher.count(builtQuery), numMatchingDocs); + final Weight weight = searcher.createWeight(builtQuery, ScoreMode.COMPLETE, 1); + // tests that the Weight#count API returns -1 instead of returning the total number of matches + assertEquals(weight.count(reader.leaves().get(0)), -1); + + IOUtils.close(reader, w, dir); + } + public void testToString() { BooleanQuery.Builder bq = new BooleanQuery.Builder(); bq.add(new TermQuery(new Term("field", "a")), Occur.SHOULD); diff --git a/lucene/core/src/test/org/apache/lucene/search/TestFilterWeight.java b/lucene/core/src/test/org/apache/lucene/search/TestFilterWeight.java index 34720261b6e..2d817e1024c 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestFilterWeight.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestFilterWeight.java @@ -36,7 +36,8 @@ public class TestFilterWeight extends LuceneTestCase { final int modifiers = superClassMethod.getModifiers(); if (Modifier.isFinal(modifiers)) continue; if (Modifier.isStatic(modifiers)) continue; - if (Arrays.asList("bulkScorer", "scorerSupplier").contains(superClassMethod.getName())) { + if (Arrays.asList("bulkScorer", "scorerSupplier", "count") + .contains(superClassMethod.getName())) { try { final Method subClassMethod = subClass.getDeclaredMethod( diff --git a/lucene/core/src/test/org/apache/lucene/search/TestTermQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestTermQuery.java index 32437a17a48..0cde70f8d09 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestTermQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestTermQuery.java @@ -103,6 +103,35 @@ public class TestTermQuery extends LuceneTestCase { IOUtils.close(reader, w, dir); } + // 
LUCENE-9620 Add Weight#count(LeafReaderContext) + public void testQueryMatchesCount() throws IOException { + Directory dir = newDirectory(); + RandomIndexWriter w = new RandomIndexWriter(random(), dir); + + int randomNumDocs = random().nextInt(500); + int numMatchingDocs = 0; + + for (int i = 0; i < randomNumDocs; i++) { + Document doc = new Document(); + if (random().nextBoolean()) { + doc.add(new StringField("foo", "bar", Store.NO)); + numMatchingDocs++; + } + w.addDocument(doc); + } + w.commit(); + + DirectoryReader reader = w.getReader(); + final IndexSearcher searcher = new IndexSearcher(reader); + + Query testQuery = new TermQuery(new Term("foo", "bar")); + assertEquals(searcher.count(testQuery), numMatchingDocs); + final Weight weight = searcher.createWeight(testQuery, ScoreMode.COMPLETE, 1); + assertEquals(weight.count(reader.leaves().get(0)), numMatchingDocs); + + IOUtils.close(reader, w, dir); + } + public void testGetTermStates() throws Exception { // no term states: