LUCENE-9620 Add Weight#count(LeafReaderContext) (#242)

Add a default implementation in Weight.java and add sample faster
implementations in MatchAllDocsQuery, MatchNoDocsQuery, TermQuery

Add tests for BooleanQuery and TermQuery

Co-authored-by: Gautam Worah <gauworah@amazon.com>
Co-authored-by: Adrien Grand <jpountz@gmail.com>
This commit is contained in:
Gautam Worah 2021-09-03 00:09:38 -07:00 committed by GitHub
parent 059d06cec7
commit 44e9f5de53
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 181 additions and 47 deletions

View File

@ -143,6 +143,10 @@ API Changes
* LUCENE-10027: Directory reader open API from indexCommit and leafSorter has been modified
to add an extra parameter - minSupportedMajorVersion. (Mayya Sharipova)
* LUCENE-9620: Added a (sometimes) faster implementation for IndexSearcher#count that relies on the new Weight#count API.
The Weight#count API represents a cleaner way for Query classes to optimize their counting method.
(Gautam Worah, Adrien Grand)
Improvements
* LUCENE-9960: Avoid unnecessary top element replacement for equal elements in PriorityQueue. (Dawid Weiss)

View File

@ -168,6 +168,11 @@ public final class ConstantScoreQuery extends Query {
public boolean isCacheable(LeafReaderContext ctx) {
return innerWeight.isCacheable(ctx);
}
@Override
public int count(LeafReaderContext context) throws IOException {
return innerWeight.count(context);
}
};
} else {
return innerWeight;

View File

@ -402,49 +402,57 @@ public class IndexSearcher {
return similarity;
}
private static class ShortcutHitCountCollector implements Collector {
private final Weight weight;
private final TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
private int weightCount;
ShortcutHitCountCollector(Weight weight) {
this.weight = weight;
}
@Override
public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
int count = weight.count(context);
// check if the number of hits can be computed in constant time
if (count == -1) {
// use a TotalHitCountCollector to calculate the number of hits in the usual way
return totalHitCountCollector.getLeafCollector(context);
} else {
weightCount += count;
throw new CollectionTerminatedException();
}
}
@Override
public ScoreMode scoreMode() {
return ScoreMode.COMPLETE_NO_SCORES;
}
}
/** Count how many documents match the given query. */
public int count(Query query) throws IOException {
query = rewrite(query);
while (true) {
// remove wrappers that don't matter for counts
if (query instanceof ConstantScoreQuery) {
query = ((ConstantScoreQuery) query).getQuery();
} else {
break;
}
}
// some counts can be computed in constant time
if (query instanceof MatchAllDocsQuery) {
return reader.numDocs();
} else if (query instanceof TermQuery && reader.hasDeletions() == false) {
Term term = ((TermQuery) query).getTerm();
int count = 0;
for (LeafReaderContext leaf : reader.leaves()) {
count += leaf.reader().docFreq(term);
}
return count;
}
// general case: create a collector and count matches
final CollectorManager<TotalHitCountCollector, Integer> collectorManager =
new CollectorManager<TotalHitCountCollector, Integer>() {
final Weight weight = createWeight(query, ScoreMode.COMPLETE_NO_SCORES, 1);
final CollectorManager<ShortcutHitCountCollector, Integer> shortcutCollectorManager =
new CollectorManager<ShortcutHitCountCollector, Integer>() {
@Override
public TotalHitCountCollector newCollector() throws IOException {
return new TotalHitCountCollector();
public ShortcutHitCountCollector newCollector() throws IOException {
return new ShortcutHitCountCollector(weight);
}
@Override
public Integer reduce(Collection<TotalHitCountCollector> collectors) throws IOException {
int total = 0;
for (TotalHitCountCollector collector : collectors) {
total += collector.getTotalHits();
public Integer reduce(Collection<ShortcutHitCountCollector> collectors)
throws IOException {
int totalHitCount = 0;
for (ShortcutHitCountCollector c : collectors) {
totalHitCount += c.weightCount + c.totalHitCountCollector.getTotalHits();
}
return total;
return totalHitCount;
}
};
return search(query, collectorManager);
return search(weight, shortcutCollectorManager, new ShortcutHitCountCollector(weight));
}
/**
@ -659,29 +667,29 @@ public class IndexSearcher {
*/
public <C extends Collector, T> T search(Query query, CollectorManager<C, T> collectorManager)
throws IOException {
final C firstCollector = collectorManager.newCollector();
query = rewrite(query);
final Weight weight = createWeight(query, firstCollector.scoreMode(), 1);
return search(weight, collectorManager, firstCollector);
}
private <C extends Collector, T> T search(
Weight weight, CollectorManager<C, T> collectorManager, C firstCollector) throws IOException {
if (executor == null || leafSlices.length <= 1) {
final C collector = collectorManager.newCollector();
search(query, collector);
return collectorManager.reduce(Collections.singletonList(collector));
search(leafContexts, weight, firstCollector);
return collectorManager.reduce(Collections.singletonList(firstCollector));
} else {
final List<C> collectors = new ArrayList<>(leafSlices.length);
ScoreMode scoreMode = null;
for (int i = 0; i < leafSlices.length; ++i) {
collectors.add(firstCollector);
final ScoreMode scoreMode = firstCollector.scoreMode();
for (int i = 1; i < leafSlices.length; ++i) {
final C collector = collectorManager.newCollector();
collectors.add(collector);
if (scoreMode == null) {
scoreMode = collector.scoreMode();
} else if (scoreMode != collector.scoreMode()) {
if (scoreMode != collector.scoreMode()) {
throw new IllegalStateException(
"CollectorManager does not always produce collectors with the same score mode");
}
}
if (scoreMode == null) {
// no segments
scoreMode = ScoreMode.COMPLETE;
}
query = rewrite(query);
final Weight weight = createWeight(query, scoreMode, 1);
final List<FutureTask<C>> listTasks = new ArrayList<>();
for (int i = 0; i < leafSlices.length; ++i) {
final LeafReaderContext[] leaves = leafSlices[i].leaves;

View File

@ -72,6 +72,11 @@ public final class MatchAllDocsQuery extends Query {
}
};
}
@Override
public int count(LeafReaderContext context) {
return context.reader().numDocs();
}
};
}

View File

@ -52,6 +52,11 @@ public class MatchNoDocsQuery extends Query {
public boolean isCacheable(LeafReaderContext ctx) {
return true;
}
@Override
public int count(LeafReaderContext context) {
return 0;
}
};
}

View File

@ -179,6 +179,22 @@ public class TermQuery extends Query {
}
return Explanation.noMatch("no matching term");
}
@Override
public int count(LeafReaderContext context) throws IOException {
if (context.reader().hasDeletions() == false) {
TermsEnum termsEnum = getTermsEnum(context);
// termsEnum is not null if term state is available
if (termsEnum != null) {
return termsEnum.docFreq();
} else {
// the term cannot be found in the dictionary so the count is 0
return 0;
}
} else {
return super.count(context);
}
}
}
/** Constructs a query for the term <code>t</code>. */

View File

@ -174,6 +174,27 @@ public abstract class Weight implements SegmentCacheable {
return new DefaultBulkScorer(scorer);
}
/**
* Counts the number of live documents that match a given {@link Weight#parentQuery} in a leaf.
*
* <p>The default implementation returns -1 for every query. This indicates that the count could
* not be computed in O(1) time.
*
* <p>Specific query classes should override it to provide other accurate O(1) implementations
* (that actually return the count). Look at {@link MatchAllDocsQuery#createWeight(IndexSearcher,
* ScoreMode, float)} for an example
*
* <p>We use this property of the function to to count hits in {@link IndexSearcher#count(Query)}.
*
* @param context the {@link org.apache.lucene.index.LeafReaderContext} for which to return the
* count.
* @return integer count of the number of matches
* @throws IOException if there is a low-level I/O error
*/
public int count(LeafReaderContext context) throws IOException {
return -1;
}
/**
* Just wraps a Scorer and performs top scoring using it.
*

View File

@ -41,6 +41,7 @@ import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause.Occur;
import org.apache.lucene.search.similarities.ClassicSimilarity;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.NamedThreadFactory;
import org.apache.lucene.util.TestUtil;
@ -736,6 +737,45 @@ public class TestBooleanQuery extends LuceneTestCase {
dir.close();
}
// LUCENE-9620 Add Weight#count(LeafReaderContext)
public void testQueryMatchesCount() throws IOException {
Directory dir = newDirectory();
RandomIndexWriter w = new RandomIndexWriter(random(), dir);
int randomNumDocs = random().nextInt(500);
int numMatchingDocs = 0;
for (int i = 0; i < randomNumDocs; i++) {
Document doc = new Document();
Field f;
if (random().nextBoolean()) {
f = newTextField("field", "a b c " + random().nextInt(), Field.Store.NO);
numMatchingDocs++;
} else {
f = newTextField("field", String.valueOf(random().nextInt()), Field.Store.NO);
}
doc.add(f);
w.addDocument(doc);
}
w.commit();
DirectoryReader reader = w.getReader();
final IndexSearcher searcher = new IndexSearcher(reader);
BooleanQuery.Builder q = new BooleanQuery.Builder();
q.add(new PhraseQuery("field", "a", "b"), Occur.SHOULD);
q.add(new TermQuery(new Term("field", "c")), Occur.SHOULD);
Query builtQuery = q.build();
assertEquals(searcher.count(builtQuery), numMatchingDocs);
final Weight weight = searcher.createWeight(builtQuery, ScoreMode.COMPLETE, 1);
// tests that the Weight#count API returns -1 instead of returning the total number of matches
assertEquals(weight.count(reader.leaves().get(0)), -1);
IOUtils.close(reader, w, dir);
}
public void testToString() {
BooleanQuery.Builder bq = new BooleanQuery.Builder();
bq.add(new TermQuery(new Term("field", "a")), Occur.SHOULD);

View File

@ -36,7 +36,8 @@ public class TestFilterWeight extends LuceneTestCase {
final int modifiers = superClassMethod.getModifiers();
if (Modifier.isFinal(modifiers)) continue;
if (Modifier.isStatic(modifiers)) continue;
if (Arrays.asList("bulkScorer", "scorerSupplier").contains(superClassMethod.getName())) {
if (Arrays.asList("bulkScorer", "scorerSupplier", "count")
.contains(superClassMethod.getName())) {
try {
final Method subClassMethod =
subClass.getDeclaredMethod(

View File

@ -103,6 +103,35 @@ public class TestTermQuery extends LuceneTestCase {
IOUtils.close(reader, w, dir);
}
// LUCENE-9620 Add Weight#count(LeafReaderContext)
public void testQueryMatchesCount() throws IOException {
Directory dir = newDirectory();
RandomIndexWriter w = new RandomIndexWriter(random(), dir);
int randomNumDocs = random().nextInt(500);
int numMatchingDocs = 0;
for (int i = 0; i < randomNumDocs; i++) {
Document doc = new Document();
if (random().nextBoolean()) {
doc.add(new StringField("foo", "bar", Store.NO));
numMatchingDocs++;
}
w.addDocument(doc);
}
w.commit();
DirectoryReader reader = w.getReader();
final IndexSearcher searcher = new IndexSearcher(reader);
Query testQuery = new TermQuery(new Term("foo", "bar"));
assertEquals(searcher.count(testQuery), numMatchingDocs);
final Weight weight = searcher.createWeight(testQuery, ScoreMode.COMPLETE, 1);
assertEquals(weight.count(reader.leaves().get(0)), numMatchingDocs);
IOUtils.close(reader, w, dir);
}
public void testGetTermStates() throws Exception {
// no term states: