LUCENE-10002: Replace simple usages of TotalHitCountCollector with IndexSearcher#count (#612)

When only the number of matching documents is needed, IndexSearcher#search(Query, Collector) is commonly used with a TotalHitCountCollector, but that call does not use the executor that may have been set on the searcher. Calling `IndexSearcher#count(Query)` makes the code more concise and is also more correct, as it honours the executor that has been set on the searcher instance.

Co-authored-by: Adrien Grand <jpountz@gmail.com>
This commit is contained in:
Luca Cavanna 2022-01-25 16:11:19 +01:00 committed by GitHub
parent fd817b6fb1
commit 11006fba59
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 18 additions and 28 deletions

View File

@ -32,7 +32,6 @@ import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Query; import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRef;
/** /**
@ -179,10 +178,8 @@ public class CachingNaiveBayesClassifier extends SimpleNaiveBayesClassifier {
if (query != null) { if (query != null) {
booleanQuery.add(query, BooleanClause.Occur.MUST); booleanQuery.add(query, BooleanClause.Occur.MUST);
} }
TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
indexSearcher.search(booleanQuery.build(), totalHitCountCollector);
int ret = totalHitCountCollector.getTotalHits(); int ret = indexSearcher.count(booleanQuery.build());
if (ret != 0) { if (ret != 0) {
searched.put(cclass, ret); searched.put(cclass, ret);
} }

View File

@ -35,7 +35,6 @@ import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query; import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.search.WildcardQuery; import org.apache.lucene.search.WildcardQuery;
import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRef;
@ -169,7 +168,6 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
Terms terms = MultiTerms.getTerms(this.indexReader, this.classFieldName); Terms terms = MultiTerms.getTerms(this.indexReader, this.classFieldName);
int docCount; int docCount;
if (terms == null || terms.getDocCount() == -1) { // in case codec doesn't support getDocCount if (terms == null || terms.getDocCount() == -1) { // in case codec doesn't support getDocCount
TotalHitCountCollector classQueryCountCollector = new TotalHitCountCollector();
BooleanQuery.Builder q = new BooleanQuery.Builder(); BooleanQuery.Builder q = new BooleanQuery.Builder();
q.add( q.add(
new BooleanClause( new BooleanClause(
@ -179,8 +177,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
if (query != null) { if (query != null) {
q.add(query, BooleanClause.Occur.MUST); q.add(query, BooleanClause.Occur.MUST);
} }
indexSearcher.search(q.build(), classQueryCountCollector); docCount = indexSearcher.count(q.build());
docCount = classQueryCountCollector.getTotalHits();
} else { } else {
docCount = terms.getDocCount(); docCount = terms.getDocCount();
} }
@ -276,9 +273,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
if (query != null) { if (query != null) {
booleanQuery.add(query, BooleanClause.Occur.MUST); booleanQuery.add(query, BooleanClause.Occur.MUST);
} }
TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector(); return indexSearcher.count(booleanQuery.build());
indexSearcher.search(booleanQuery.build(), totalHitCountCollector);
return totalHitCountCollector.getTotalHits();
} }
private double calculateLogPrior(Term term, int docsWithClassSize) throws IOException { private double calculateLogPrior(Term term, int docsWithClassSize) throws IOException {

View File

@ -40,7 +40,6 @@ import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Query; import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRef;
/** /**
@ -263,9 +262,7 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi
if (query != null) { if (query != null) {
booleanQuery.add(query, BooleanClause.Occur.MUST); booleanQuery.add(query, BooleanClause.Occur.MUST);
} }
TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector(); return indexSearcher.count(booleanQuery.build());
indexSearcher.search(booleanQuery.build(), totalHitCountCollector);
return totalHitCountCollector.getTotalHits();
} }
private double calculateLogPrior(Term term, int docsWithClassSize) throws IOException { private double calculateLogPrior(Term term, int docsWithClassSize) throws IOException {

View File

@ -16,7 +16,12 @@
*/ */
package org.apache.lucene.search; package org.apache.lucene.search;
/** Just counts the total number of hits. */ /**
* Just counts the total number of hits. For cases when this is the only collector used, {@link
* IndexSearcher#count(Query)} should be called instead of {@link IndexSearcher#search(Query,
* Collector)} as the former is faster whenever the count can be returned directly from the index
* statistics.
*/
public class TotalHitCountCollector extends SimpleCollector { public class TotalHitCountCollector extends SimpleCollector {
private int totalHits; private int totalHits;

View File

@ -62,9 +62,8 @@ public class TestFuzzyTermOnShortTerms extends LuceneTestCase {
Directory d = getDirectory(analyzer, docs); Directory d = getDirectory(analyzer, docs);
IndexReader r = DirectoryReader.open(d); IndexReader r = DirectoryReader.open(d);
IndexSearcher s = new IndexSearcher(r); IndexSearcher s = new IndexSearcher(r);
TotalHitCountCollector c = new TotalHitCountCollector(); int totalHits = s.count(q);
s.search(q, c); assertEquals(q.toString(), expected, totalHits);
assertEquals(q.toString(), expected, c.getTotalHits());
r.close(); r.close();
d.close(); d.close();
} }

View File

@ -1169,7 +1169,7 @@ public class TestLRUQueryCache extends LuceneTestCase {
searcher.setQueryCachingPolicy(ALWAYS_CACHE); searcher.setQueryCachingPolicy(ALWAYS_CACHE);
BadQuery query = new BadQuery(); BadQuery query = new BadQuery();
searcher.count(query); searcher.search(query, new TotalHitCountCollector());
query.i[0] += 1; // change the hashCode! query.i[0] += 1; // change the hashCode!
try { try {

View File

@ -681,15 +681,14 @@ public class TestJoinUtil extends LuceneTestCase {
Query joinQuery = Query joinQuery =
JoinUtil.createJoinQuery( JoinUtil.createJoinQuery(
"join_field", fromQuery, toQuery, searcher, scoreMode, ordinalMap, min, max); "join_field", fromQuery, toQuery, searcher, scoreMode, ordinalMap, min, max);
TotalHitCountCollector collector = new TotalHitCountCollector(); int totalHits = searcher.count(joinQuery);
searcher.search(joinQuery, collector);
int expectedCount = 0; int expectedCount = 0;
for (int numChildDocs : childDocsPerParent) { for (int numChildDocs : childDocsPerParent) {
if (numChildDocs >= min && numChildDocs <= max) { if (numChildDocs >= min && numChildDocs <= max) {
expectedCount++; expectedCount++;
} }
} }
assertEquals(expectedCount, collector.getTotalHits()); assertEquals(expectedCount, totalHits);
} }
searcher.getIndexReader().close(); searcher.getIndexReader().close();
dir.close(); dir.close();

View File

@ -24,7 +24,6 @@ import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import org.apache.lucene.search.Query; import org.apache.lucene.search.Query;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.spatial.StrategyTestCase; import org.apache.lucene.spatial.StrategyTestCase;
import org.apache.lucene.spatial.prefix.tree.QuadPrefixTree; import org.apache.lucene.spatial.prefix.tree.QuadPrefixTree;
import org.apache.lucene.spatial.prefix.tree.SpatialPrefixTree; import org.apache.lucene.spatial.prefix.tree.SpatialPrefixTree;
@ -282,13 +281,12 @@ public class TestHeatmapFacetCounter extends StrategyTestCase {
Query filter = Query filter =
new IntersectsPrefixTreeQuery( new IntersectsPrefixTreeQuery(
pt, strategy.getFieldName(), grid, facetLevel, grid.getMaxLevels()); pt, strategy.getFieldName(), grid, facetLevel, grid.getMaxLevels());
final TotalHitCountCollector collector = new TotalHitCountCollector(); int totalHits = indexSearcher.count(filter);
indexSearcher.search(filter, collector);
cellsValidated++; cellsValidated++;
if (collector.getTotalHits() > 0) { if (totalHits > 0) {
cellValidatedNonZero++; cellValidatedNonZero++;
} }
return collector.getTotalHits(); return totalHits;
} }
private Shape randomIndexedShape() { private Shape randomIndexedShape() {