LUCENE-10620: Pass the Weight to Collectors. (#964)

This allows `Collector`s to use `Weight#count` when appropriate.
Adrien Grand 2022-06-23 17:56:15 +02:00 committed by GitHub
parent 2da9951f23
commit 4c1ae2a332
16 changed files with 197 additions and 85 deletions

Collector.java

@@ -56,4 +56,11 @@ public interface Collector {
/** Indicates what features are required from the scorer. */
ScoreMode scoreMode();
/**
* Set the {@link Weight} that will be used to produce scorers that will feed {@link
* LeafCollector}s. This is typically useful to have access to {@link Weight#count} from {@link
* Collector#getLeafCollector}.
*/
default void setWeight(Weight weight) {}
}
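
For illustration only (not part of this diff), a minimal sketch of a collector that takes advantage of the new hook; it closely mirrors the updated TotalHitCountCollector further down:

import java.io.IOException;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;

// Hypothetical example class, not part of this commit.
class CountingCollector implements Collector {
  private Weight weight;
  private int totalHits;

  @Override
  public void setWeight(Weight weight) {
    // IndexSearcher calls this before asking for any LeafCollector.
    this.weight = weight;
  }

  @Override
  public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
    // Weight#count returns the number of matches in this segment, or -1 if it
    // cannot be computed cheaply.
    int leafCount = weight == null ? -1 : weight.count(context);
    if (leafCount != -1) {
      totalHits += leafCount;
      // Tell IndexSearcher that this segment does not need to be collected.
      throw new CollectionTerminatedException();
    }
    return new LeafCollector() {
      @Override
      public void setScorer(Scorable scorer) {}

      @Override
      public void collect(int doc) {
        totalHits++;
      }
    };
  }

  @Override
  public ScoreMode scoreMode() {
    return ScoreMode.COMPLETE_NO_SCORES;
  }

  int getTotalHits() {
    return totalHits;
  }
}

Throwing CollectionTerminatedException from getLeafCollector is the existing way to tell IndexSearcher to skip a segment; having the Weight available at that point is what makes the per-segment shortcut possible.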

FilterCollector.java

@@ -38,6 +38,11 @@ public abstract class FilterCollector implements Collector {
return in.getLeafCollector(context);
}
@Override
public void setWeight(Weight weight) {
in.setWeight(weight);
}
@Override
public String toString() {
return getClass().getSimpleName() + "(" + in + ")";

IndexSearcher.java

@@ -412,61 +412,13 @@ public class IndexSearcher {
return similarity;
}
private static class ShortcutHitCountCollector implements Collector {
private final Weight weight;
private final TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
private int weightCount;
ShortcutHitCountCollector(Weight weight) {
this.weight = weight;
}
@Override
public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
int count = weight.count(context);
// check if the number of hits can be computed in constant time
if (count == -1) {
// use a TotalHitCountCollector to calculate the number of hits in the usual way
return totalHitCountCollector.getLeafCollector(context);
} else {
weightCount += count;
throw new CollectionTerminatedException();
}
}
@Override
public ScoreMode scoreMode() {
return ScoreMode.COMPLETE_NO_SCORES;
}
}
/**
* Count how many documents match the given query. May be faster than counting number of hits by
* collecting all matches, as the number of hits is retrieved from the index statistics when
* possible.
*/
public int count(Query query) throws IOException {
query = rewrite(query, false);
final Weight weight = createWeight(query, ScoreMode.COMPLETE_NO_SCORES, 1);
final CollectorManager<ShortcutHitCountCollector, Integer> shortcutCollectorManager =
new CollectorManager<ShortcutHitCountCollector, Integer>() {
@Override
public ShortcutHitCountCollector newCollector() throws IOException {
return new ShortcutHitCountCollector(weight);
}
@Override
public Integer reduce(Collection<ShortcutHitCountCollector> collectors)
throws IOException {
int totalHitCount = 0;
for (ShortcutHitCountCollector c : collectors) {
totalHitCount += c.weightCount + c.totalHitCountCollector.getTotalHits();
}
return totalHitCount;
}
};
return search(weight, shortcutCollectorManager, new ShortcutHitCountCollector(weight));
return search(new ConstantScoreQuery(query), new TotalHitCountCollectorManager());
}
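
A hedged usage sketch of the public API this simplifies (assumes an existing IndexReader named reader; the field and term are made up for illustration):

// Hypothetical usage: count matches without materializing hits or scores.
IndexSearcher searcher = new IndexSearcher(reader);
int matches = searcher.count(new TermQuery(new Term("color", "blue")));

The new implementation simply delegates to the TotalHitCountCollector path shown further down; the ConstantScoreQuery wrapper makes it explicit that no scores are needed.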
/**
@@ -750,6 +702,8 @@ public class IndexSearcher {
protected void search(List<LeafReaderContext> leaves, Weight weight, Collector collector)
throws IOException {
collector.setWeight(weight);
// TODO: should we make this
// threaded...? the Collector could be sync'd?
// always use single thread:
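
For context, a simplified sketch of what this protected per-leaf loop does around the new setWeight call (an illustrative sketch only, reduced relative to the real method, which handles more cases):

// Simplified sketch, not the verbatim implementation.
collector.setWeight(weight); // collectors now see the Weight before any leaf is collected
for (LeafReaderContext ctx : leaves) {
  final LeafCollector leafCollector;
  try {
    leafCollector = collector.getLeafCollector(ctx);
  } catch (CollectionTerminatedException e) {
    // e.g. TotalHitCountCollector already got this segment's count from Weight#count
    continue;
  }
  BulkScorer scorer = weight.bulkScorer(ctx);
  if (scorer != null) {
    try {
      scorer.score(leafCollector, ctx.reader().getLiveDocs());
    } catch (CollectionTerminatedException e) {
      // collection of this segment was terminated early
    }
  }
  leafCollector.finish();
}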

MultiCollector.java

@@ -149,6 +149,13 @@ public class MultiCollector implements Collector {
}
}
@Override
public void setWeight(Weight weight) {
for (Collector collector : collectors) {
collector.setWeight(weight);
}
}
/** Provides access to the wrapped {@code Collector}s for advanced use-cases */
public Collector[] getCollectors() {
return collectors;

TotalHitCountCollector.java

@@ -16,13 +16,16 @@
*/
package org.apache.lucene.search;
import java.io.IOException;
import org.apache.lucene.index.LeafReaderContext;
/**
* Just counts the total number of hits. For cases when this is the only collector used, {@link
* IndexSearcher#count(Query)} should be called instead of {@link IndexSearcher#search(Query,
* Collector)} as the former is faster whenever the count can be returned directly from the index
* statistics.
* Just counts the total number of hits. This is the collector behind {@link IndexSearcher#count}.
* When the {@link Weight} implements {@link Weight#count}, this collector will skip collecting
* segments.
*/
public class TotalHitCountCollector extends SimpleCollector {
public class TotalHitCountCollector implements Collector {
private Weight weight;
private int totalHits;
/** Returns how many hits matched the search. */
@@ -30,13 +33,32 @@ public class TotalHitCountCollector extends SimpleCollector {
return totalHits;
}
@Override
public void collect(int doc) {
totalHits++;
}
@Override
public ScoreMode scoreMode() {
return ScoreMode.COMPLETE_NO_SCORES;
}
@Override
public void setWeight(Weight weight) {
this.weight = weight;
}
@Override
public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
int leafCount = weight == null ? -1 : weight.count(context);
if (leafCount != -1) {
totalHits += leafCount;
throw new CollectionTerminatedException();
}
return new LeafCollector() {
@Override
public void setScorer(Scorable scorer) throws IOException {}
@Override
public void collect(int doc) throws IOException {
totalHits++;
}
};
}
}
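
A short usage sketch of the updated collector through its manager, matching what the updated test further down does (assumes an existing IndexReader reader and Query query):

// Hypothetical usage; the manager reduces to the total hit count as an Integer.
IndexSearcher searcher = new IndexSearcher(reader);
int totalHits = searcher.search(query, new TotalHitCountCollectorManager());
// Per segment, the collector first consults Weight#count and only falls back to
// collecting hit by hit when the count cannot be computed cheaply (-1).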

TestBooleanQuery.java

@@ -45,6 +45,7 @@ import org.apache.lucene.search.similarities.ClassicSimilarity;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.search.DummyTotalHitCountCollector;
import org.apache.lucene.tests.search.FixedBitSetCollector;
import org.apache.lucene.tests.search.QueryUtils;
import org.apache.lucene.tests.util.LuceneTestCase;
@@ -1021,7 +1022,7 @@ public class TestBooleanQuery extends LuceneTestCase {
builder.setMinimumNumberShouldMatch(TestUtil.nextInt(random(), 0, numShouldClauses));
Query booleanQuery = builder.build();
assertEquals(
(int) searcher.search(booleanQuery, new TotalHitCountCollectorManager()),
(int) searcher.search(booleanQuery, DummyTotalHitCountCollector.createManager()),
searcher.count(booleanQuery));
}
reader.close();

TestLRUQueryCache.java

@@ -64,6 +64,7 @@ import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.search.AssertingIndexSearcher;
import org.apache.lucene.tests.search.CheckHits;
import org.apache.lucene.tests.search.DummyTotalHitCountCollector;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.RamUsageTester;
import org.apache.lucene.tests.util.TestUtil;
@@ -168,8 +169,8 @@ public class TestLRUQueryCache extends LuceneTestCase {
RandomPicks.randomFrom(
random(), new String[] {"blue", "red", "yellow", "green"});
final Query q = new TermQuery(new Term("color", value));
TotalHitCountCollectorManager collectorManager =
new TotalHitCountCollectorManager();
CollectorManager<DummyTotalHitCountCollector, Integer> collectorManager =
DummyTotalHitCountCollector.createManager();
// will use the cache
final int totalHits1 = searcher.search(q, collectorManager);
final long totalHits2 =
@@ -177,8 +178,8 @@ public class TestLRUQueryCache extends LuceneTestCase {
q,
new CollectorManager<FilterCollector, Integer>() {
@Override
public FilterCollector newCollector() {
return new FilterCollector(new TotalHitCountCollector()) {
public FilterCollector newCollector() throws IOException {
return new FilterCollector(collectorManager.newCollector()) {
@Override
public ScoreMode scoreMode() {
// will not use the cache because of scores
@@ -194,7 +195,7 @@ public class TestLRUQueryCache extends LuceneTestCase {
collectors.stream()
.map(
filterCollector ->
(TotalHitCountCollector) filterCollector.in)
(DummyTotalHitCountCollector) filterCollector.in)
.collect(Collectors.toList()));
}
});
@@ -963,7 +964,7 @@ public class TestLRUQueryCache extends LuceneTestCase {
searcher.setQueryCache(queryCache);
searcher.setQueryCachingPolicy(policy);
searcher.search(query.build(), new TotalHitCountCollectorManager());
searcher.search(query.build(), DummyTotalHitCountCollector.createManager());
reader.close();
dir.close();
@@ -1187,12 +1188,12 @@ public class TestLRUQueryCache extends LuceneTestCase {
searcher.setQueryCachingPolicy(ALWAYS_CACHE);
BadQuery query = new BadQuery();
searcher.search(query, new TotalHitCountCollectorManager());
searcher.search(query, DummyTotalHitCountCollector.createManager());
query.i[0] += 1; // change the hashCode!
try {
// trigger an eviction
searcher.search(new MatchAllDocsQuery(), new TotalHitCountCollectorManager());
searcher.search(new MatchAllDocsQuery(), DummyTotalHitCountCollector.createManager());
fail();
} catch (
@SuppressWarnings("unused")
@@ -1273,7 +1274,7 @@ public class TestLRUQueryCache extends LuceneTestCase {
query.add(bar, Occur.FILTER);
query.add(foo, Occur.FILTER);
}
indexSearcher.search(query.build(), new TotalHitCountCollectorManager());
indexSearcher.search(query.build(), DummyTotalHitCountCollector.createManager());
assertEquals(1, policy.frequency(query.build()));
assertEquals(1, policy.frequency(foo));
assertEquals(1, policy.frequency(bar));

TestMultiCollector.java

@@ -32,6 +32,7 @@ import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.search.DummyTotalHitCountCollector;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;
import org.junit.Test;
@@ -101,13 +102,13 @@ public class TestMultiCollector extends LuceneTestCase {
final IndexReader reader = w.getReader();
w.close();
final IndexSearcher searcher = newSearcher(reader, true, true, false);
Map<TotalHitCountCollector, Integer> expectedCounts = new HashMap<>();
Map<DummyTotalHitCountCollector, Integer> expectedCounts = new HashMap<>();
List<Collector> collectors = new ArrayList<>();
final int numCollectors = TestUtil.nextInt(random(), 1, 5);
for (int i = 0; i < numCollectors; ++i) {
final int terminateAfter = random().nextInt(numDocs + 10);
final int expectedCount = terminateAfter > numDocs ? numDocs : terminateAfter;
TotalHitCountCollector collector = new TotalHitCountCollector();
DummyTotalHitCountCollector collector = new DummyTotalHitCountCollector();
expectedCounts.put(collector, expectedCount);
collectors.add(new TerminateAfterCollector(collector, terminateAfter));
}
@@ -124,7 +125,8 @@ public class TestMultiCollector extends LuceneTestCase {
return null;
}
});
for (Map.Entry<TotalHitCountCollector, Integer> expectedCount : expectedCounts.entrySet()) {
for (Map.Entry<DummyTotalHitCountCollector, Integer> expectedCount :
expectedCounts.entrySet()) {
assertEquals(expectedCount.getValue().intValue(), expectedCount.getKey().getTotalHits());
}
reader.close();
@@ -133,8 +135,8 @@ public class TestMultiCollector extends LuceneTestCase {
}
public void testSetScorerAfterCollectionTerminated() throws IOException {
Collector collector1 = new TotalHitCountCollector();
Collector collector2 = new TotalHitCountCollector();
Collector collector1 = new DummyTotalHitCountCollector();
Collector collector2 = new DummyTotalHitCountCollector();
AtomicBoolean setScorerCalled1 = new AtomicBoolean();
collector1 = new SetScorerCollector(collector1, setScorerCalled1);
@@ -224,7 +226,7 @@ public class TestMultiCollector extends LuceneTestCase {
scorer.setMinCompetitiveScore(minScore);
}
};
Collector multiCollector = MultiCollector.wrap(collector, new TotalHitCountCollector());
Collector multiCollector = MultiCollector.wrap(collector, new DummyTotalHitCountCollector());
LeafCollector leafCollector = multiCollector.getLeafCollector(reader.leaves().get(0));
leafCollector.setScorer(scorer);
leafCollector.collect(0); // no exception
@@ -283,7 +285,7 @@ public class TestMultiCollector extends LuceneTestCase {
List<Collector> cols = new ArrayList<>();
cols.add(collector);
for (int col = 0; col < numCol; col++) {
cols.add(new TerminateAfterCollector(new TotalHitCountCollector(), 0));
cols.add(new TerminateAfterCollector(new DummyTotalHitCountCollector(), 0));
}
Collections.shuffle(cols, random());
Collector multiCollector = MultiCollector.wrap(cols);

TestSearchWithThreads.java

@@ -24,6 +24,7 @@ import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.search.DummyTotalHitCountCollector;
import org.apache.lucene.tests.util.LuceneTestCase;
public class TestSearchWithThreads extends LuceneTestCase {
@@ -57,7 +58,7 @@ public class TestSearchWithThreads extends LuceneTestCase {
final AtomicBoolean failed = new AtomicBoolean();
final AtomicLong netSearch = new AtomicLong();
TotalHitCountCollectorManager collectorManager = new TotalHitCountCollectorManager();
CollectorManager<?, Integer> collectorManager = DummyTotalHitCountCollector.createManager();
Thread[] threads = new Thread[numThreads];
for (int threadID = 0; threadID < numThreads; threadID++) {
threads[threadID] =

TestTermQuery.java

@@ -34,6 +34,7 @@ import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.search.DummyTotalHitCountCollector;
import org.apache.lucene.tests.search.QueryUtils;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;
@@ -91,14 +92,13 @@ public class TestTermQuery extends LuceneTestCase {
IndexSearcher searcher = new IndexSearcher(reader);
// use a collector rather than searcher.count() which would just read the
// doc freq instead of creating a scorer
TotalHitCountCollectorManager collectorManager = new TotalHitCountCollectorManager();
int totalHits = searcher.search(query, collectorManager);
int totalHits = searcher.search(query, DummyTotalHitCountCollector.createManager());
assertEquals(1, totalHits);
TermQuery queryWithContext =
new TermQuery(
new Term("foo", "bar"),
TermStates.build(reader.getContext(), new Term("foo", "bar"), true));
totalHits = searcher.search(queryWithContext, collectorManager);
totalHits = searcher.search(queryWithContext, DummyTotalHitCountCollector.createManager());
assertEquals(1, totalHits);
IOUtils.close(reader, w, dir);
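
The comment above is why this test avoids searcher.count(): for a single-term query the count can often be answered from index statistics without ever creating a Scorer. A hedged sketch of that kind of Weight#count shortcut (not the verbatim TermQuery code; `term` is assumed to be the query's Term):

// Rough sketch of a statistics-based count for a term query.
@Override
public int count(LeafReaderContext context) throws IOException {
  if (context.reader().hasDeletions() == false) {
    return context.reader().docFreq(term); // doc freq equals the hit count when nothing is deleted
  }
  return -1; // deletions make the cheap answer unreliable; collect normally instead
}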

TestTotalHitCountCollector.java

@@ -20,6 +20,8 @@ import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause.Occur;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.LuceneTestCase;
@@ -42,6 +44,15 @@ public class TestTotalHitCountCollector extends LuceneTestCase {
TotalHitCountCollectorManager collectorManager = new TotalHitCountCollectorManager();
int totalHits = searcher.search(new MatchAllDocsQuery(), collectorManager);
assertEquals(5, totalHits);
Query query =
new BooleanQuery.Builder()
.add(new TermQuery(new Term("string", "a1")), Occur.SHOULD)
.add(new TermQuery(new Term("string", "b3")), Occur.SHOULD)
.build();
totalHits = searcher.search(query, collectorManager);
assertEquals(2, totalHits);
reader.close();
indexStore.close();
}

ProfilerCollector.java

@@ -24,6 +24,7 @@ import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
/**
* This class wraps a Collector and times the execution of: - setScorer() - collect() -
@@ -83,6 +84,11 @@ public class ProfilerCollector implements Collector {
return collector.getLeafCollector(context);
}
@Override
public void setWeight(Weight weight) {
collector.setWeight(weight);
}
@Override
public ScoreMode scoreMode() {
return collector.scoreMode();

TestIndexSortSortedNumericDocValuesRangeQuery.java

@@ -40,11 +40,11 @@ import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.SortedNumericSortField;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHitCountCollectorManager;
import org.apache.lucene.search.Weight;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.search.DummyTotalHitCountCollector;
import org.apache.lucene.tests.search.QueryUtils;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;
@@ -221,7 +221,8 @@ public class TestIndexSortSortedNumericDocValuesRangeQuery extends LuceneTestCase {
private static void assertNumberOfHits(IndexSearcher searcher, Query query, int numberOfHits)
throws IOException {
assertEquals(
numberOfHits, searcher.search(query, new TotalHitCountCollectorManager()).intValue());
numberOfHits,
searcher.search(query, DummyTotalHitCountCollector.createManager()).intValue());
assertEquals(numberOfHits, searcher.count(query));
}

TestMultiRangeQueries.java

@@ -37,9 +37,9 @@ import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHitCountCollectorManager;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.search.DummyTotalHitCountCollector;
import org.apache.lucene.tests.search.QueryUtils;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;
@@ -808,8 +808,8 @@ public class TestMultiRangeQueries extends LuceneTestCase {
MultiRangeQuery multiRangeQuery = (MultiRangeQuery) builder1.build().rewrite(reader);
BooleanQuery booleanQuery = builder2.build();
int count = searcher.search(multiRangeQuery, new TotalHitCountCollectorManager());
int booleanCount = searcher.search(booleanQuery, new TotalHitCountCollectorManager());
int count = searcher.search(multiRangeQuery, DummyTotalHitCountCollector.createManager());
int booleanCount = searcher.search(booleanQuery, DummyTotalHitCountCollector.createManager());
assertEquals(booleanCount, count);
}
IOUtils.close(reader, w, dir);

AssertingCollector.java

@@ -22,10 +22,12 @@ import org.apache.lucene.search.Collector;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.FilterCollector;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Weight;
/** A collector that asserts that it is used correctly. */
class AssertingCollector extends FilterCollector {
private boolean weightSet = false;
private int maxDoc = -1;
private int previousLeafMaxDoc = 0;
@@ -43,6 +45,7 @@ class AssertingCollector extends FilterCollector {
@Override
public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
assert weightSet : "Set the weight first";
assert context.docBase >= previousLeafMaxDoc;
previousLeafMaxDoc = context.docBase + context.reader().maxDoc();
@@ -65,4 +68,12 @@ class AssertingCollector extends FilterCollector {
}
};
}
@Override
public void setWeight(Weight weight) {
assert weightSet == false : "Weight set twice";
weightSet = true;
assert weight != null;
in.setWeight(weight);
}
}

DummyTotalHitCountCollector.java (new file)

@@ -0,0 +1,83 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.tests.search;
import java.io.IOException;
import java.util.Collection;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.search.Weight;
/**
* A dummy version of {@link TotalHitCountCollector} that doesn't shortcut using {@link
* Weight#count}.
*/
public class DummyTotalHitCountCollector implements Collector {
private int totalHits;
/** Constructor */
public DummyTotalHitCountCollector() {}
/** Get the number of hits. */
public int getTotalHits() {
return totalHits;
}
@Override
public ScoreMode scoreMode() {
return ScoreMode.COMPLETE_NO_SCORES;
}
@Override
public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
return new LeafCollector() {
@Override
public void setScorer(Scorable scorer) throws IOException {}
@Override
public void collect(int doc) throws IOException {
totalHits++;
}
};
}
/** Create a collector manager. */
public static CollectorManager<DummyTotalHitCountCollector, Integer> createManager() {
return new CollectorManager<DummyTotalHitCountCollector, Integer>() {
@Override
public DummyTotalHitCountCollector newCollector() throws IOException {
return new DummyTotalHitCountCollector();
}
@Override
public Integer reduce(Collection<DummyTotalHitCountCollector> collectors) throws IOException {
int sum = 0;
for (DummyTotalHitCountCollector coll : collectors) {
sum += coll.totalHits;
}
return sum;
}
};
}
}
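
Typical test usage, matching the call sites updated above (query and searcher construction assumed):

// Forces per-document collection so the test exercises real Scorers rather
// than the Weight#count shortcut.
int hits = searcher.search(query, DummyTotalHitCountCollector.createManager());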