mirror of https://github.com/apache/lucene.git
LUCENE-10620: Pass the Weight to Collectors. (#964)
This allows `Collector`s to use `Weight#count` when appropriate.
This commit is contained in:
parent
2da9951f23
commit
4c1ae2a332
|
@ -56,4 +56,11 @@ public interface Collector {
|
||||||
|
|
||||||
/** Indicates what features are required from the scorer. */
|
/** Indicates what features are required from the scorer. */
|
||||||
ScoreMode scoreMode();
|
ScoreMode scoreMode();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set the {@link Weight} that will be used to produce scorers that will feed {@link
|
||||||
|
* LeafCollector}s. This is typically useful to have access to {@link Weight#count} from {@link
|
||||||
|
* Collector#getLeafCollector}.
|
||||||
|
*/
|
||||||
|
default void setWeight(Weight weight) {}
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,6 +38,11 @@ public abstract class FilterCollector implements Collector {
|
||||||
return in.getLeafCollector(context);
|
return in.getLeafCollector(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setWeight(Weight weight) {
|
||||||
|
in.setWeight(weight);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String toString() {
|
public String toString() {
|
||||||
return getClass().getSimpleName() + "(" + in + ")";
|
return getClass().getSimpleName() + "(" + in + ")";
|
||||||
|
|
|
@ -412,61 +412,13 @@ public class IndexSearcher {
|
||||||
return similarity;
|
return similarity;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static class ShortcutHitCountCollector implements Collector {
|
|
||||||
private final Weight weight;
|
|
||||||
private final TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
|
|
||||||
private int weightCount;
|
|
||||||
|
|
||||||
ShortcutHitCountCollector(Weight weight) {
|
|
||||||
this.weight = weight;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
|
|
||||||
int count = weight.count(context);
|
|
||||||
// check if the number of hits can be computed in constant time
|
|
||||||
if (count == -1) {
|
|
||||||
// use a TotalHitCountCollector to calculate the number of hits in the usual way
|
|
||||||
return totalHitCountCollector.getLeafCollector(context);
|
|
||||||
} else {
|
|
||||||
weightCount += count;
|
|
||||||
throw new CollectionTerminatedException();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public ScoreMode scoreMode() {
|
|
||||||
return ScoreMode.COMPLETE_NO_SCORES;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Count how many documents match the given query. May be faster than counting number of hits by
|
* Count how many documents match the given query. May be faster than counting number of hits by
|
||||||
* collecting all matches, as the number of hits is retrieved from the index statistics when
|
* collecting all matches, as the number of hits is retrieved from the index statistics when
|
||||||
* possible.
|
* possible.
|
||||||
*/
|
*/
|
||||||
public int count(Query query) throws IOException {
|
public int count(Query query) throws IOException {
|
||||||
query = rewrite(query, false);
|
return search(new ConstantScoreQuery(query), new TotalHitCountCollectorManager());
|
||||||
final Weight weight = createWeight(query, ScoreMode.COMPLETE_NO_SCORES, 1);
|
|
||||||
|
|
||||||
final CollectorManager<ShortcutHitCountCollector, Integer> shortcutCollectorManager =
|
|
||||||
new CollectorManager<ShortcutHitCountCollector, Integer>() {
|
|
||||||
@Override
|
|
||||||
public ShortcutHitCountCollector newCollector() throws IOException {
|
|
||||||
return new ShortcutHitCountCollector(weight);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Integer reduce(Collection<ShortcutHitCountCollector> collectors)
|
|
||||||
throws IOException {
|
|
||||||
int totalHitCount = 0;
|
|
||||||
for (ShortcutHitCountCollector c : collectors) {
|
|
||||||
totalHitCount += c.weightCount + c.totalHitCountCollector.getTotalHits();
|
|
||||||
}
|
|
||||||
return totalHitCount;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
return search(weight, shortcutCollectorManager, new ShortcutHitCountCollector(weight));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -750,6 +702,8 @@ public class IndexSearcher {
|
||||||
protected void search(List<LeafReaderContext> leaves, Weight weight, Collector collector)
|
protected void search(List<LeafReaderContext> leaves, Weight weight, Collector collector)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
|
|
||||||
|
collector.setWeight(weight);
|
||||||
|
|
||||||
// TODO: should we make this
|
// TODO: should we make this
|
||||||
// threaded...? the Collector could be sync'd?
|
// threaded...? the Collector could be sync'd?
|
||||||
// always use single thread:
|
// always use single thread:
|
||||||
|
|
|
@ -149,6 +149,13 @@ public class MultiCollector implements Collector {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setWeight(Weight weight) {
|
||||||
|
for (Collector collector : collectors) {
|
||||||
|
collector.setWeight(weight);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/** Provides access to the wrapped {@code Collector}s for advanced use-cases */
|
/** Provides access to the wrapped {@code Collector}s for advanced use-cases */
|
||||||
public Collector[] getCollectors() {
|
public Collector[] getCollectors() {
|
||||||
return collectors;
|
return collectors;
|
||||||
|
|
|
@ -16,13 +16,16 @@
|
||||||
*/
|
*/
|
||||||
package org.apache.lucene.search;
|
package org.apache.lucene.search;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import org.apache.lucene.index.LeafReaderContext;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Just counts the total number of hits. For cases when this is the only collector used, {@link
|
* Just counts the total number of hits. This is the collector behind {@link IndexSearcher#count}.
|
||||||
* IndexSearcher#count(Query)} should be called instead of {@link IndexSearcher#search(Query,
|
* When the {@link Weight} implements {@link Weight#count}, this collector will skip collecting
|
||||||
* Collector)} as the former is faster whenever the count can be returned directly from the index
|
* segments.
|
||||||
* statistics.
|
|
||||||
*/
|
*/
|
||||||
public class TotalHitCountCollector extends SimpleCollector {
|
public class TotalHitCountCollector implements Collector {
|
||||||
|
private Weight weight;
|
||||||
private int totalHits;
|
private int totalHits;
|
||||||
|
|
||||||
/** Returns how many hits matched the search. */
|
/** Returns how many hits matched the search. */
|
||||||
|
@ -30,13 +33,32 @@ public class TotalHitCountCollector extends SimpleCollector {
|
||||||
return totalHits;
|
return totalHits;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public void collect(int doc) {
|
|
||||||
totalHits++;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ScoreMode scoreMode() {
|
public ScoreMode scoreMode() {
|
||||||
return ScoreMode.COMPLETE_NO_SCORES;
|
return ScoreMode.COMPLETE_NO_SCORES;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setWeight(Weight weight) {
|
||||||
|
this.weight = weight;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
|
||||||
|
int leafCount = weight == null ? -1 : weight.count(context);
|
||||||
|
if (leafCount != -1) {
|
||||||
|
totalHits += leafCount;
|
||||||
|
throw new CollectionTerminatedException();
|
||||||
|
}
|
||||||
|
return new LeafCollector() {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setScorer(Scorable scorer) throws IOException {}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void collect(int doc) throws IOException {
|
||||||
|
totalHits++;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,6 +45,7 @@ import org.apache.lucene.search.similarities.ClassicSimilarity;
|
||||||
import org.apache.lucene.store.Directory;
|
import org.apache.lucene.store.Directory;
|
||||||
import org.apache.lucene.tests.analysis.MockAnalyzer;
|
import org.apache.lucene.tests.analysis.MockAnalyzer;
|
||||||
import org.apache.lucene.tests.index.RandomIndexWriter;
|
import org.apache.lucene.tests.index.RandomIndexWriter;
|
||||||
|
import org.apache.lucene.tests.search.DummyTotalHitCountCollector;
|
||||||
import org.apache.lucene.tests.search.FixedBitSetCollector;
|
import org.apache.lucene.tests.search.FixedBitSetCollector;
|
||||||
import org.apache.lucene.tests.search.QueryUtils;
|
import org.apache.lucene.tests.search.QueryUtils;
|
||||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||||
|
@ -1021,7 +1022,7 @@ public class TestBooleanQuery extends LuceneTestCase {
|
||||||
builder.setMinimumNumberShouldMatch(TestUtil.nextInt(random(), 0, numShouldClauses));
|
builder.setMinimumNumberShouldMatch(TestUtil.nextInt(random(), 0, numShouldClauses));
|
||||||
Query booleanQuery = builder.build();
|
Query booleanQuery = builder.build();
|
||||||
assertEquals(
|
assertEquals(
|
||||||
(int) searcher.search(booleanQuery, new TotalHitCountCollectorManager()),
|
(int) searcher.search(booleanQuery, DummyTotalHitCountCollector.createManager()),
|
||||||
searcher.count(booleanQuery));
|
searcher.count(booleanQuery));
|
||||||
}
|
}
|
||||||
reader.close();
|
reader.close();
|
||||||
|
|
|
@ -64,6 +64,7 @@ import org.apache.lucene.store.Directory;
|
||||||
import org.apache.lucene.tests.index.RandomIndexWriter;
|
import org.apache.lucene.tests.index.RandomIndexWriter;
|
||||||
import org.apache.lucene.tests.search.AssertingIndexSearcher;
|
import org.apache.lucene.tests.search.AssertingIndexSearcher;
|
||||||
import org.apache.lucene.tests.search.CheckHits;
|
import org.apache.lucene.tests.search.CheckHits;
|
||||||
|
import org.apache.lucene.tests.search.DummyTotalHitCountCollector;
|
||||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||||
import org.apache.lucene.tests.util.RamUsageTester;
|
import org.apache.lucene.tests.util.RamUsageTester;
|
||||||
import org.apache.lucene.tests.util.TestUtil;
|
import org.apache.lucene.tests.util.TestUtil;
|
||||||
|
@ -168,8 +169,8 @@ public class TestLRUQueryCache extends LuceneTestCase {
|
||||||
RandomPicks.randomFrom(
|
RandomPicks.randomFrom(
|
||||||
random(), new String[] {"blue", "red", "yellow", "green"});
|
random(), new String[] {"blue", "red", "yellow", "green"});
|
||||||
final Query q = new TermQuery(new Term("color", value));
|
final Query q = new TermQuery(new Term("color", value));
|
||||||
TotalHitCountCollectorManager collectorManager =
|
CollectorManager<DummyTotalHitCountCollector, Integer> collectorManager =
|
||||||
new TotalHitCountCollectorManager();
|
DummyTotalHitCountCollector.createManager();
|
||||||
// will use the cache
|
// will use the cache
|
||||||
final int totalHits1 = searcher.search(q, collectorManager);
|
final int totalHits1 = searcher.search(q, collectorManager);
|
||||||
final long totalHits2 =
|
final long totalHits2 =
|
||||||
|
@ -177,8 +178,8 @@ public class TestLRUQueryCache extends LuceneTestCase {
|
||||||
q,
|
q,
|
||||||
new CollectorManager<FilterCollector, Integer>() {
|
new CollectorManager<FilterCollector, Integer>() {
|
||||||
@Override
|
@Override
|
||||||
public FilterCollector newCollector() {
|
public FilterCollector newCollector() throws IOException {
|
||||||
return new FilterCollector(new TotalHitCountCollector()) {
|
return new FilterCollector(collectorManager.newCollector()) {
|
||||||
@Override
|
@Override
|
||||||
public ScoreMode scoreMode() {
|
public ScoreMode scoreMode() {
|
||||||
// will not use the cache because of scores
|
// will not use the cache because of scores
|
||||||
|
@ -194,7 +195,7 @@ public class TestLRUQueryCache extends LuceneTestCase {
|
||||||
collectors.stream()
|
collectors.stream()
|
||||||
.map(
|
.map(
|
||||||
filterCollector ->
|
filterCollector ->
|
||||||
(TotalHitCountCollector) filterCollector.in)
|
(DummyTotalHitCountCollector) filterCollector.in)
|
||||||
.collect(Collectors.toList()));
|
.collect(Collectors.toList()));
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -963,7 +964,7 @@ public class TestLRUQueryCache extends LuceneTestCase {
|
||||||
|
|
||||||
searcher.setQueryCache(queryCache);
|
searcher.setQueryCache(queryCache);
|
||||||
searcher.setQueryCachingPolicy(policy);
|
searcher.setQueryCachingPolicy(policy);
|
||||||
searcher.search(query.build(), new TotalHitCountCollectorManager());
|
searcher.search(query.build(), DummyTotalHitCountCollector.createManager());
|
||||||
|
|
||||||
reader.close();
|
reader.close();
|
||||||
dir.close();
|
dir.close();
|
||||||
|
@ -1187,12 +1188,12 @@ public class TestLRUQueryCache extends LuceneTestCase {
|
||||||
searcher.setQueryCachingPolicy(ALWAYS_CACHE);
|
searcher.setQueryCachingPolicy(ALWAYS_CACHE);
|
||||||
|
|
||||||
BadQuery query = new BadQuery();
|
BadQuery query = new BadQuery();
|
||||||
searcher.search(query, new TotalHitCountCollectorManager());
|
searcher.search(query, DummyTotalHitCountCollector.createManager());
|
||||||
query.i[0] += 1; // change the hashCode!
|
query.i[0] += 1; // change the hashCode!
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// trigger an eviction
|
// trigger an eviction
|
||||||
searcher.search(new MatchAllDocsQuery(), new TotalHitCountCollectorManager());
|
searcher.search(new MatchAllDocsQuery(), DummyTotalHitCountCollector.createManager());
|
||||||
fail();
|
fail();
|
||||||
} catch (
|
} catch (
|
||||||
@SuppressWarnings("unused")
|
@SuppressWarnings("unused")
|
||||||
|
@ -1273,7 +1274,7 @@ public class TestLRUQueryCache extends LuceneTestCase {
|
||||||
query.add(bar, Occur.FILTER);
|
query.add(bar, Occur.FILTER);
|
||||||
query.add(foo, Occur.FILTER);
|
query.add(foo, Occur.FILTER);
|
||||||
}
|
}
|
||||||
indexSearcher.search(query.build(), new TotalHitCountCollectorManager());
|
indexSearcher.search(query.build(), DummyTotalHitCountCollector.createManager());
|
||||||
assertEquals(1, policy.frequency(query.build()));
|
assertEquals(1, policy.frequency(query.build()));
|
||||||
assertEquals(1, policy.frequency(foo));
|
assertEquals(1, policy.frequency(foo));
|
||||||
assertEquals(1, policy.frequency(bar));
|
assertEquals(1, policy.frequency(bar));
|
||||||
|
|
|
@ -32,6 +32,7 @@ import org.apache.lucene.index.IndexWriter;
|
||||||
import org.apache.lucene.index.LeafReaderContext;
|
import org.apache.lucene.index.LeafReaderContext;
|
||||||
import org.apache.lucene.store.Directory;
|
import org.apache.lucene.store.Directory;
|
||||||
import org.apache.lucene.tests.index.RandomIndexWriter;
|
import org.apache.lucene.tests.index.RandomIndexWriter;
|
||||||
|
import org.apache.lucene.tests.search.DummyTotalHitCountCollector;
|
||||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||||
import org.apache.lucene.tests.util.TestUtil;
|
import org.apache.lucene.tests.util.TestUtil;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
@ -101,13 +102,13 @@ public class TestMultiCollector extends LuceneTestCase {
|
||||||
final IndexReader reader = w.getReader();
|
final IndexReader reader = w.getReader();
|
||||||
w.close();
|
w.close();
|
||||||
final IndexSearcher searcher = newSearcher(reader, true, true, false);
|
final IndexSearcher searcher = newSearcher(reader, true, true, false);
|
||||||
Map<TotalHitCountCollector, Integer> expectedCounts = new HashMap<>();
|
Map<DummyTotalHitCountCollector, Integer> expectedCounts = new HashMap<>();
|
||||||
List<Collector> collectors = new ArrayList<>();
|
List<Collector> collectors = new ArrayList<>();
|
||||||
final int numCollectors = TestUtil.nextInt(random(), 1, 5);
|
final int numCollectors = TestUtil.nextInt(random(), 1, 5);
|
||||||
for (int i = 0; i < numCollectors; ++i) {
|
for (int i = 0; i < numCollectors; ++i) {
|
||||||
final int terminateAfter = random().nextInt(numDocs + 10);
|
final int terminateAfter = random().nextInt(numDocs + 10);
|
||||||
final int expectedCount = terminateAfter > numDocs ? numDocs : terminateAfter;
|
final int expectedCount = terminateAfter > numDocs ? numDocs : terminateAfter;
|
||||||
TotalHitCountCollector collector = new TotalHitCountCollector();
|
DummyTotalHitCountCollector collector = new DummyTotalHitCountCollector();
|
||||||
expectedCounts.put(collector, expectedCount);
|
expectedCounts.put(collector, expectedCount);
|
||||||
collectors.add(new TerminateAfterCollector(collector, terminateAfter));
|
collectors.add(new TerminateAfterCollector(collector, terminateAfter));
|
||||||
}
|
}
|
||||||
|
@ -124,7 +125,8 @@ public class TestMultiCollector extends LuceneTestCase {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
for (Map.Entry<TotalHitCountCollector, Integer> expectedCount : expectedCounts.entrySet()) {
|
for (Map.Entry<DummyTotalHitCountCollector, Integer> expectedCount :
|
||||||
|
expectedCounts.entrySet()) {
|
||||||
assertEquals(expectedCount.getValue().intValue(), expectedCount.getKey().getTotalHits());
|
assertEquals(expectedCount.getValue().intValue(), expectedCount.getKey().getTotalHits());
|
||||||
}
|
}
|
||||||
reader.close();
|
reader.close();
|
||||||
|
@ -133,8 +135,8 @@ public class TestMultiCollector extends LuceneTestCase {
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testSetScorerAfterCollectionTerminated() throws IOException {
|
public void testSetScorerAfterCollectionTerminated() throws IOException {
|
||||||
Collector collector1 = new TotalHitCountCollector();
|
Collector collector1 = new DummyTotalHitCountCollector();
|
||||||
Collector collector2 = new TotalHitCountCollector();
|
Collector collector2 = new DummyTotalHitCountCollector();
|
||||||
|
|
||||||
AtomicBoolean setScorerCalled1 = new AtomicBoolean();
|
AtomicBoolean setScorerCalled1 = new AtomicBoolean();
|
||||||
collector1 = new SetScorerCollector(collector1, setScorerCalled1);
|
collector1 = new SetScorerCollector(collector1, setScorerCalled1);
|
||||||
|
@ -224,7 +226,7 @@ public class TestMultiCollector extends LuceneTestCase {
|
||||||
scorer.setMinCompetitiveScore(minScore);
|
scorer.setMinCompetitiveScore(minScore);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
Collector multiCollector = MultiCollector.wrap(collector, new TotalHitCountCollector());
|
Collector multiCollector = MultiCollector.wrap(collector, new DummyTotalHitCountCollector());
|
||||||
LeafCollector leafCollector = multiCollector.getLeafCollector(reader.leaves().get(0));
|
LeafCollector leafCollector = multiCollector.getLeafCollector(reader.leaves().get(0));
|
||||||
leafCollector.setScorer(scorer);
|
leafCollector.setScorer(scorer);
|
||||||
leafCollector.collect(0); // no exception
|
leafCollector.collect(0); // no exception
|
||||||
|
@ -283,7 +285,7 @@ public class TestMultiCollector extends LuceneTestCase {
|
||||||
List<Collector> cols = new ArrayList<>();
|
List<Collector> cols = new ArrayList<>();
|
||||||
cols.add(collector);
|
cols.add(collector);
|
||||||
for (int col = 0; col < numCol; col++) {
|
for (int col = 0; col < numCol; col++) {
|
||||||
cols.add(new TerminateAfterCollector(new TotalHitCountCollector(), 0));
|
cols.add(new TerminateAfterCollector(new DummyTotalHitCountCollector(), 0));
|
||||||
}
|
}
|
||||||
Collections.shuffle(cols, random());
|
Collections.shuffle(cols, random());
|
||||||
Collector multiCollector = MultiCollector.wrap(cols);
|
Collector multiCollector = MultiCollector.wrap(cols);
|
||||||
|
|
|
@ -24,6 +24,7 @@ import org.apache.lucene.index.IndexReader;
|
||||||
import org.apache.lucene.index.Term;
|
import org.apache.lucene.index.Term;
|
||||||
import org.apache.lucene.store.Directory;
|
import org.apache.lucene.store.Directory;
|
||||||
import org.apache.lucene.tests.index.RandomIndexWriter;
|
import org.apache.lucene.tests.index.RandomIndexWriter;
|
||||||
|
import org.apache.lucene.tests.search.DummyTotalHitCountCollector;
|
||||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||||
|
|
||||||
public class TestSearchWithThreads extends LuceneTestCase {
|
public class TestSearchWithThreads extends LuceneTestCase {
|
||||||
|
@ -57,7 +58,7 @@ public class TestSearchWithThreads extends LuceneTestCase {
|
||||||
|
|
||||||
final AtomicBoolean failed = new AtomicBoolean();
|
final AtomicBoolean failed = new AtomicBoolean();
|
||||||
final AtomicLong netSearch = new AtomicLong();
|
final AtomicLong netSearch = new AtomicLong();
|
||||||
TotalHitCountCollectorManager collectorManager = new TotalHitCountCollectorManager();
|
CollectorManager<?, Integer> collectorManager = DummyTotalHitCountCollector.createManager();
|
||||||
Thread[] threads = new Thread[numThreads];
|
Thread[] threads = new Thread[numThreads];
|
||||||
for (int threadID = 0; threadID < numThreads; threadID++) {
|
for (int threadID = 0; threadID < numThreads; threadID++) {
|
||||||
threads[threadID] =
|
threads[threadID] =
|
||||||
|
|
|
@ -34,6 +34,7 @@ import org.apache.lucene.index.Terms;
|
||||||
import org.apache.lucene.index.TermsEnum;
|
import org.apache.lucene.index.TermsEnum;
|
||||||
import org.apache.lucene.store.Directory;
|
import org.apache.lucene.store.Directory;
|
||||||
import org.apache.lucene.tests.index.RandomIndexWriter;
|
import org.apache.lucene.tests.index.RandomIndexWriter;
|
||||||
|
import org.apache.lucene.tests.search.DummyTotalHitCountCollector;
|
||||||
import org.apache.lucene.tests.search.QueryUtils;
|
import org.apache.lucene.tests.search.QueryUtils;
|
||||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||||
import org.apache.lucene.tests.util.TestUtil;
|
import org.apache.lucene.tests.util.TestUtil;
|
||||||
|
@ -91,14 +92,13 @@ public class TestTermQuery extends LuceneTestCase {
|
||||||
IndexSearcher searcher = new IndexSearcher(reader);
|
IndexSearcher searcher = new IndexSearcher(reader);
|
||||||
// use a collector rather than searcher.count() which would just read the
|
// use a collector rather than searcher.count() which would just read the
|
||||||
// doc freq instead of creating a scorer
|
// doc freq instead of creating a scorer
|
||||||
TotalHitCountCollectorManager collectorManager = new TotalHitCountCollectorManager();
|
int totalHits = searcher.search(query, DummyTotalHitCountCollector.createManager());
|
||||||
int totalHits = searcher.search(query, collectorManager);
|
|
||||||
assertEquals(1, totalHits);
|
assertEquals(1, totalHits);
|
||||||
TermQuery queryWithContext =
|
TermQuery queryWithContext =
|
||||||
new TermQuery(
|
new TermQuery(
|
||||||
new Term("foo", "bar"),
|
new Term("foo", "bar"),
|
||||||
TermStates.build(reader.getContext(), new Term("foo", "bar"), true));
|
TermStates.build(reader.getContext(), new Term("foo", "bar"), true));
|
||||||
totalHits = searcher.search(queryWithContext, collectorManager);
|
totalHits = searcher.search(queryWithContext, DummyTotalHitCountCollector.createManager());
|
||||||
assertEquals(1, totalHits);
|
assertEquals(1, totalHits);
|
||||||
|
|
||||||
IOUtils.close(reader, w, dir);
|
IOUtils.close(reader, w, dir);
|
||||||
|
|
|
@ -20,6 +20,8 @@ import org.apache.lucene.document.Document;
|
||||||
import org.apache.lucene.document.Field;
|
import org.apache.lucene.document.Field;
|
||||||
import org.apache.lucene.document.StringField;
|
import org.apache.lucene.document.StringField;
|
||||||
import org.apache.lucene.index.IndexReader;
|
import org.apache.lucene.index.IndexReader;
|
||||||
|
import org.apache.lucene.index.Term;
|
||||||
|
import org.apache.lucene.search.BooleanClause.Occur;
|
||||||
import org.apache.lucene.store.Directory;
|
import org.apache.lucene.store.Directory;
|
||||||
import org.apache.lucene.tests.index.RandomIndexWriter;
|
import org.apache.lucene.tests.index.RandomIndexWriter;
|
||||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||||
|
@ -42,6 +44,15 @@ public class TestTotalHitCountCollector extends LuceneTestCase {
|
||||||
TotalHitCountCollectorManager collectorManager = new TotalHitCountCollectorManager();
|
TotalHitCountCollectorManager collectorManager = new TotalHitCountCollectorManager();
|
||||||
int totalHits = searcher.search(new MatchAllDocsQuery(), collectorManager);
|
int totalHits = searcher.search(new MatchAllDocsQuery(), collectorManager);
|
||||||
assertEquals(5, totalHits);
|
assertEquals(5, totalHits);
|
||||||
|
|
||||||
|
Query query =
|
||||||
|
new BooleanQuery.Builder()
|
||||||
|
.add(new TermQuery(new Term("string", "a1")), Occur.SHOULD)
|
||||||
|
.add(new TermQuery(new Term("string", "b3")), Occur.SHOULD)
|
||||||
|
.build();
|
||||||
|
totalHits = searcher.search(query, collectorManager);
|
||||||
|
assertEquals(2, totalHits);
|
||||||
|
|
||||||
reader.close();
|
reader.close();
|
||||||
indexStore.close();
|
indexStore.close();
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,6 +24,7 @@ import org.apache.lucene.index.LeafReaderContext;
|
||||||
import org.apache.lucene.search.Collector;
|
import org.apache.lucene.search.Collector;
|
||||||
import org.apache.lucene.search.LeafCollector;
|
import org.apache.lucene.search.LeafCollector;
|
||||||
import org.apache.lucene.search.ScoreMode;
|
import org.apache.lucene.search.ScoreMode;
|
||||||
|
import org.apache.lucene.search.Weight;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This class wraps a Collector and times the execution of: - setScorer() - collect() -
|
* This class wraps a Collector and times the execution of: - setScorer() - collect() -
|
||||||
|
@ -83,6 +84,11 @@ public class ProfilerCollector implements Collector {
|
||||||
return collector.getLeafCollector(context);
|
return collector.getLeafCollector(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setWeight(Weight weight) {
|
||||||
|
collector.setWeight(weight);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ScoreMode scoreMode() {
|
public ScoreMode scoreMode() {
|
||||||
return collector.scoreMode();
|
return collector.scoreMode();
|
||||||
|
|
|
@ -40,11 +40,11 @@ import org.apache.lucene.search.Sort;
|
||||||
import org.apache.lucene.search.SortField;
|
import org.apache.lucene.search.SortField;
|
||||||
import org.apache.lucene.search.SortedNumericSortField;
|
import org.apache.lucene.search.SortedNumericSortField;
|
||||||
import org.apache.lucene.search.TopDocs;
|
import org.apache.lucene.search.TopDocs;
|
||||||
import org.apache.lucene.search.TotalHitCountCollectorManager;
|
|
||||||
import org.apache.lucene.search.Weight;
|
import org.apache.lucene.search.Weight;
|
||||||
import org.apache.lucene.store.Directory;
|
import org.apache.lucene.store.Directory;
|
||||||
import org.apache.lucene.tests.analysis.MockAnalyzer;
|
import org.apache.lucene.tests.analysis.MockAnalyzer;
|
||||||
import org.apache.lucene.tests.index.RandomIndexWriter;
|
import org.apache.lucene.tests.index.RandomIndexWriter;
|
||||||
|
import org.apache.lucene.tests.search.DummyTotalHitCountCollector;
|
||||||
import org.apache.lucene.tests.search.QueryUtils;
|
import org.apache.lucene.tests.search.QueryUtils;
|
||||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||||
import org.apache.lucene.tests.util.TestUtil;
|
import org.apache.lucene.tests.util.TestUtil;
|
||||||
|
@ -221,7 +221,8 @@ public class TestIndexSortSortedNumericDocValuesRangeQuery extends LuceneTestCas
|
||||||
private static void assertNumberOfHits(IndexSearcher searcher, Query query, int numberOfHits)
|
private static void assertNumberOfHits(IndexSearcher searcher, Query query, int numberOfHits)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
assertEquals(
|
assertEquals(
|
||||||
numberOfHits, searcher.search(query, new TotalHitCountCollectorManager()).intValue());
|
numberOfHits,
|
||||||
|
searcher.search(query, DummyTotalHitCountCollector.createManager()).intValue());
|
||||||
assertEquals(numberOfHits, searcher.count(query));
|
assertEquals(numberOfHits, searcher.count(query));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -37,9 +37,9 @@ import org.apache.lucene.search.Query;
|
||||||
import org.apache.lucene.search.ScoreMode;
|
import org.apache.lucene.search.ScoreMode;
|
||||||
import org.apache.lucene.search.Sort;
|
import org.apache.lucene.search.Sort;
|
||||||
import org.apache.lucene.search.TopDocs;
|
import org.apache.lucene.search.TopDocs;
|
||||||
import org.apache.lucene.search.TotalHitCountCollectorManager;
|
|
||||||
import org.apache.lucene.store.Directory;
|
import org.apache.lucene.store.Directory;
|
||||||
import org.apache.lucene.tests.index.RandomIndexWriter;
|
import org.apache.lucene.tests.index.RandomIndexWriter;
|
||||||
|
import org.apache.lucene.tests.search.DummyTotalHitCountCollector;
|
||||||
import org.apache.lucene.tests.search.QueryUtils;
|
import org.apache.lucene.tests.search.QueryUtils;
|
||||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||||
import org.apache.lucene.tests.util.TestUtil;
|
import org.apache.lucene.tests.util.TestUtil;
|
||||||
|
@ -808,8 +808,8 @@ public class TestMultiRangeQueries extends LuceneTestCase {
|
||||||
|
|
||||||
MultiRangeQuery multiRangeQuery = (MultiRangeQuery) builder1.build().rewrite(reader);
|
MultiRangeQuery multiRangeQuery = (MultiRangeQuery) builder1.build().rewrite(reader);
|
||||||
BooleanQuery booleanQuery = builder2.build();
|
BooleanQuery booleanQuery = builder2.build();
|
||||||
int count = searcher.search(multiRangeQuery, new TotalHitCountCollectorManager());
|
int count = searcher.search(multiRangeQuery, DummyTotalHitCountCollector.createManager());
|
||||||
int booleanCount = searcher.search(booleanQuery, new TotalHitCountCollectorManager());
|
int booleanCount = searcher.search(booleanQuery, DummyTotalHitCountCollector.createManager());
|
||||||
assertEquals(booleanCount, count);
|
assertEquals(booleanCount, count);
|
||||||
}
|
}
|
||||||
IOUtils.close(reader, w, dir);
|
IOUtils.close(reader, w, dir);
|
||||||
|
|
|
@ -22,10 +22,12 @@ import org.apache.lucene.search.Collector;
|
||||||
import org.apache.lucene.search.DocIdSetIterator;
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
import org.apache.lucene.search.FilterCollector;
|
import org.apache.lucene.search.FilterCollector;
|
||||||
import org.apache.lucene.search.LeafCollector;
|
import org.apache.lucene.search.LeafCollector;
|
||||||
|
import org.apache.lucene.search.Weight;
|
||||||
|
|
||||||
/** A collector that asserts that it is used correctly. */
|
/** A collector that asserts that it is used correctly. */
|
||||||
class AssertingCollector extends FilterCollector {
|
class AssertingCollector extends FilterCollector {
|
||||||
|
|
||||||
|
private boolean weightSet = false;
|
||||||
private int maxDoc = -1;
|
private int maxDoc = -1;
|
||||||
private int previousLeafMaxDoc = 0;
|
private int previousLeafMaxDoc = 0;
|
||||||
|
|
||||||
|
@ -43,6 +45,7 @@ class AssertingCollector extends FilterCollector {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
|
public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
|
||||||
|
assert weightSet : "Set the weight first";
|
||||||
assert context.docBase >= previousLeafMaxDoc;
|
assert context.docBase >= previousLeafMaxDoc;
|
||||||
previousLeafMaxDoc = context.docBase + context.reader().maxDoc();
|
previousLeafMaxDoc = context.docBase + context.reader().maxDoc();
|
||||||
|
|
||||||
|
@ -65,4 +68,12 @@ class AssertingCollector extends FilterCollector {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setWeight(Weight weight) {
|
||||||
|
assert weightSet == false : "Weight set twice";
|
||||||
|
weightSet = true;
|
||||||
|
assert weight != null;
|
||||||
|
in.setWeight(weight);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,83 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.lucene.tests.search;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Collection;
|
||||||
|
import org.apache.lucene.index.LeafReaderContext;
|
||||||
|
import org.apache.lucene.search.Collector;
|
||||||
|
import org.apache.lucene.search.CollectorManager;
|
||||||
|
import org.apache.lucene.search.LeafCollector;
|
||||||
|
import org.apache.lucene.search.Scorable;
|
||||||
|
import org.apache.lucene.search.ScoreMode;
|
||||||
|
import org.apache.lucene.search.TotalHitCountCollector;
|
||||||
|
import org.apache.lucene.search.Weight;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A dummy version of {@link TotalHitCountCollector} that doesn't shortcut using {@link
|
||||||
|
* Weight#count}.
|
||||||
|
*/
|
||||||
|
public class DummyTotalHitCountCollector implements Collector {
|
||||||
|
private int totalHits;
|
||||||
|
|
||||||
|
/** Constructor */
|
||||||
|
public DummyTotalHitCountCollector() {}
|
||||||
|
|
||||||
|
/** Get the number of hits. */
|
||||||
|
public int getTotalHits() {
|
||||||
|
return totalHits;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ScoreMode scoreMode() {
|
||||||
|
return ScoreMode.COMPLETE_NO_SCORES;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
|
||||||
|
return new LeafCollector() {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setScorer(Scorable scorer) throws IOException {}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void collect(int doc) throws IOException {
|
||||||
|
totalHits++;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Create a collector manager. */
|
||||||
|
public static CollectorManager<DummyTotalHitCountCollector, Integer> createManager() {
|
||||||
|
return new CollectorManager<DummyTotalHitCountCollector, Integer>() {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DummyTotalHitCountCollector newCollector() throws IOException {
|
||||||
|
return new DummyTotalHitCountCollector();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Integer reduce(Collection<DummyTotalHitCountCollector> collectors) throws IOException {
|
||||||
|
int sum = 0;
|
||||||
|
for (DummyTotalHitCountCollector coll : collectors) {
|
||||||
|
sum += coll.totalHits;
|
||||||
|
}
|
||||||
|
return sum;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue