diff --git a/lucene/core/src/java/org/apache/lucene/search/Collector.java b/lucene/core/src/java/org/apache/lucene/search/Collector.java index 3af2210bd6e..7c02e446755 100644 --- a/lucene/core/src/java/org/apache/lucene/search/Collector.java +++ b/lucene/core/src/java/org/apache/lucene/search/Collector.java @@ -56,4 +56,11 @@ public interface Collector { /** Indicates what features are required from the scorer. */ ScoreMode scoreMode(); + + /** + * Set the {@link Weight} that will be used to produce scorers that will feed {@link + * LeafCollector}s. This is typically useful to have access to {@link Weight#count} from {@link + * Collector#getLeafCollector}. + */ + default void setWeight(Weight weight) {} } diff --git a/lucene/core/src/java/org/apache/lucene/search/FilterCollector.java b/lucene/core/src/java/org/apache/lucene/search/FilterCollector.java index 9f57a57b00a..32c52abe1b7 100644 --- a/lucene/core/src/java/org/apache/lucene/search/FilterCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/FilterCollector.java @@ -38,6 +38,11 @@ public abstract class FilterCollector implements Collector { return in.getLeafCollector(context); } + @Override + public void setWeight(Weight weight) { + in.setWeight(weight); + } + @Override public String toString() { return getClass().getSimpleName() + "(" + in + ")"; diff --git a/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java b/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java index 4e64089f3c1..0f6bdfdeb10 100644 --- a/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java @@ -412,61 +412,13 @@ public class IndexSearcher { return similarity; } - private static class ShortcutHitCountCollector implements Collector { - private final Weight weight; - private final TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector(); - private int weightCount; - - ShortcutHitCountCollector(Weight weight) { - this.weight = weight; - } - - @Override - public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException { - int count = weight.count(context); - // check if the number of hits can be computed in constant time - if (count == -1) { - // use a TotalHitCountCollector to calculate the number of hits in the usual way - return totalHitCountCollector.getLeafCollector(context); - } else { - weightCount += count; - throw new CollectionTerminatedException(); - } - } - - @Override - public ScoreMode scoreMode() { - return ScoreMode.COMPLETE_NO_SCORES; - } - } - /** * Count how many documents match the given query. May be faster than counting number of hits by * collecting all matches, as the number of hits is retrieved from the index statistics when * possible. */ public int count(Query query) throws IOException { - query = rewrite(query, false); - final Weight weight = createWeight(query, ScoreMode.COMPLETE_NO_SCORES, 1); - - final CollectorManager shortcutCollectorManager = - new CollectorManager() { - @Override - public ShortcutHitCountCollector newCollector() throws IOException { - return new ShortcutHitCountCollector(weight); - } - - @Override - public Integer reduce(Collection collectors) - throws IOException { - int totalHitCount = 0; - for (ShortcutHitCountCollector c : collectors) { - totalHitCount += c.weightCount + c.totalHitCountCollector.getTotalHits(); - } - return totalHitCount; - } - }; - return search(weight, shortcutCollectorManager, new ShortcutHitCountCollector(weight)); + return search(new ConstantScoreQuery(query), new TotalHitCountCollectorManager()); } /** @@ -750,6 +702,8 @@ public class IndexSearcher { protected void search(List leaves, Weight weight, Collector collector) throws IOException { + collector.setWeight(weight); + // TODO: should we make this // threaded...? the Collector could be sync'd? // always use single thread: diff --git a/lucene/core/src/java/org/apache/lucene/search/MultiCollector.java b/lucene/core/src/java/org/apache/lucene/search/MultiCollector.java index 09aea3a02a3..5452c0f8d69 100644 --- a/lucene/core/src/java/org/apache/lucene/search/MultiCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/MultiCollector.java @@ -149,6 +149,13 @@ public class MultiCollector implements Collector { } } + @Override + public void setWeight(Weight weight) { + for (Collector collector : collectors) { + collector.setWeight(weight); + } + } + /** Provides access to the wrapped {@code Collector}s for advanced use-cases */ public Collector[] getCollectors() { return collectors; diff --git a/lucene/core/src/java/org/apache/lucene/search/TotalHitCountCollector.java b/lucene/core/src/java/org/apache/lucene/search/TotalHitCountCollector.java index 9d9ad4149b0..30d0659f2cd 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TotalHitCountCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/TotalHitCountCollector.java @@ -16,13 +16,16 @@ */ package org.apache.lucene.search; +import java.io.IOException; +import org.apache.lucene.index.LeafReaderContext; + /** - * Just counts the total number of hits. For cases when this is the only collector used, {@link - * IndexSearcher#count(Query)} should be called instead of {@link IndexSearcher#search(Query, - * Collector)} as the former is faster whenever the count can be returned directly from the index - * statistics. + * Just counts the total number of hits. This is the collector behind {@link IndexSearcher#count}. + * When the {@link Weight} implements {@link Weight#count}, this collector will skip collecting + * segments. */ -public class TotalHitCountCollector extends SimpleCollector { +public class TotalHitCountCollector implements Collector { + private Weight weight; private int totalHits; /** Returns how many hits matched the search. */ @@ -30,13 +33,32 @@ public class TotalHitCountCollector extends SimpleCollector { return totalHits; } - @Override - public void collect(int doc) { - totalHits++; - } - @Override public ScoreMode scoreMode() { return ScoreMode.COMPLETE_NO_SCORES; } + + @Override + public void setWeight(Weight weight) { + this.weight = weight; + } + + @Override + public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException { + int leafCount = weight == null ? -1 : weight.count(context); + if (leafCount != -1) { + totalHits += leafCount; + throw new CollectionTerminatedException(); + } + return new LeafCollector() { + + @Override + public void setScorer(Scorable scorer) throws IOException {} + + @Override + public void collect(int doc) throws IOException { + totalHits++; + } + }; + } } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestBooleanQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestBooleanQuery.java index 6a391a38274..a5710121acf 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestBooleanQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestBooleanQuery.java @@ -45,6 +45,7 @@ import org.apache.lucene.search.similarities.ClassicSimilarity; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.tests.search.DummyTotalHitCountCollector; import org.apache.lucene.tests.search.FixedBitSetCollector; import org.apache.lucene.tests.search.QueryUtils; import org.apache.lucene.tests.util.LuceneTestCase; @@ -1021,7 +1022,7 @@ public class TestBooleanQuery extends LuceneTestCase { builder.setMinimumNumberShouldMatch(TestUtil.nextInt(random(), 0, numShouldClauses)); Query booleanQuery = builder.build(); assertEquals( - (int) searcher.search(booleanQuery, new TotalHitCountCollectorManager()), + (int) searcher.search(booleanQuery, DummyTotalHitCountCollector.createManager()), searcher.count(booleanQuery)); } reader.close(); diff --git a/lucene/core/src/test/org/apache/lucene/search/TestLRUQueryCache.java b/lucene/core/src/test/org/apache/lucene/search/TestLRUQueryCache.java index a30bb757e60..10826517b1d 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestLRUQueryCache.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestLRUQueryCache.java @@ -64,6 +64,7 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.search.AssertingIndexSearcher; import org.apache.lucene.tests.search.CheckHits; +import org.apache.lucene.tests.search.DummyTotalHitCountCollector; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.RamUsageTester; import org.apache.lucene.tests.util.TestUtil; @@ -168,8 +169,8 @@ public class TestLRUQueryCache extends LuceneTestCase { RandomPicks.randomFrom( random(), new String[] {"blue", "red", "yellow", "green"}); final Query q = new TermQuery(new Term("color", value)); - TotalHitCountCollectorManager collectorManager = - new TotalHitCountCollectorManager(); + CollectorManager collectorManager = + DummyTotalHitCountCollector.createManager(); // will use the cache final int totalHits1 = searcher.search(q, collectorManager); final long totalHits2 = @@ -177,8 +178,8 @@ public class TestLRUQueryCache extends LuceneTestCase { q, new CollectorManager() { @Override - public FilterCollector newCollector() { - return new FilterCollector(new TotalHitCountCollector()) { + public FilterCollector newCollector() throws IOException { + return new FilterCollector(collectorManager.newCollector()) { @Override public ScoreMode scoreMode() { // will not use the cache because of scores @@ -194,7 +195,7 @@ public class TestLRUQueryCache extends LuceneTestCase { collectors.stream() .map( filterCollector -> - (TotalHitCountCollector) filterCollector.in) + (DummyTotalHitCountCollector) filterCollector.in) .collect(Collectors.toList())); } }); @@ -963,7 +964,7 @@ public class TestLRUQueryCache extends LuceneTestCase { searcher.setQueryCache(queryCache); searcher.setQueryCachingPolicy(policy); - searcher.search(query.build(), new TotalHitCountCollectorManager()); + searcher.search(query.build(), DummyTotalHitCountCollector.createManager()); reader.close(); dir.close(); @@ -1187,12 +1188,12 @@ public class TestLRUQueryCache extends LuceneTestCase { searcher.setQueryCachingPolicy(ALWAYS_CACHE); BadQuery query = new BadQuery(); - searcher.search(query, new TotalHitCountCollectorManager()); + searcher.search(query, DummyTotalHitCountCollector.createManager()); query.i[0] += 1; // change the hashCode! try { // trigger an eviction - searcher.search(new MatchAllDocsQuery(), new TotalHitCountCollectorManager()); + searcher.search(new MatchAllDocsQuery(), DummyTotalHitCountCollector.createManager()); fail(); } catch ( @SuppressWarnings("unused") @@ -1273,7 +1274,7 @@ public class TestLRUQueryCache extends LuceneTestCase { query.add(bar, Occur.FILTER); query.add(foo, Occur.FILTER); } - indexSearcher.search(query.build(), new TotalHitCountCollectorManager()); + indexSearcher.search(query.build(), DummyTotalHitCountCollector.createManager()); assertEquals(1, policy.frequency(query.build())); assertEquals(1, policy.frequency(foo)); assertEquals(1, policy.frequency(bar)); diff --git a/lucene/core/src/test/org/apache/lucene/search/TestMultiCollector.java b/lucene/core/src/test/org/apache/lucene/search/TestMultiCollector.java index 9f8a2f5902d..a8fc829bcc1 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestMultiCollector.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestMultiCollector.java @@ -32,6 +32,7 @@ import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.tests.search.DummyTotalHitCountCollector; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.TestUtil; import org.junit.Test; @@ -101,13 +102,13 @@ public class TestMultiCollector extends LuceneTestCase { final IndexReader reader = w.getReader(); w.close(); final IndexSearcher searcher = newSearcher(reader, true, true, false); - Map expectedCounts = new HashMap<>(); + Map expectedCounts = new HashMap<>(); List collectors = new ArrayList<>(); final int numCollectors = TestUtil.nextInt(random(), 1, 5); for (int i = 0; i < numCollectors; ++i) { final int terminateAfter = random().nextInt(numDocs + 10); final int expectedCount = terminateAfter > numDocs ? numDocs : terminateAfter; - TotalHitCountCollector collector = new TotalHitCountCollector(); + DummyTotalHitCountCollector collector = new DummyTotalHitCountCollector(); expectedCounts.put(collector, expectedCount); collectors.add(new TerminateAfterCollector(collector, terminateAfter)); } @@ -124,7 +125,8 @@ public class TestMultiCollector extends LuceneTestCase { return null; } }); - for (Map.Entry expectedCount : expectedCounts.entrySet()) { + for (Map.Entry expectedCount : + expectedCounts.entrySet()) { assertEquals(expectedCount.getValue().intValue(), expectedCount.getKey().getTotalHits()); } reader.close(); @@ -133,8 +135,8 @@ public class TestMultiCollector extends LuceneTestCase { } public void testSetScorerAfterCollectionTerminated() throws IOException { - Collector collector1 = new TotalHitCountCollector(); - Collector collector2 = new TotalHitCountCollector(); + Collector collector1 = new DummyTotalHitCountCollector(); + Collector collector2 = new DummyTotalHitCountCollector(); AtomicBoolean setScorerCalled1 = new AtomicBoolean(); collector1 = new SetScorerCollector(collector1, setScorerCalled1); @@ -224,7 +226,7 @@ public class TestMultiCollector extends LuceneTestCase { scorer.setMinCompetitiveScore(minScore); } }; - Collector multiCollector = MultiCollector.wrap(collector, new TotalHitCountCollector()); + Collector multiCollector = MultiCollector.wrap(collector, new DummyTotalHitCountCollector()); LeafCollector leafCollector = multiCollector.getLeafCollector(reader.leaves().get(0)); leafCollector.setScorer(scorer); leafCollector.collect(0); // no exception @@ -283,7 +285,7 @@ public class TestMultiCollector extends LuceneTestCase { List cols = new ArrayList<>(); cols.add(collector); for (int col = 0; col < numCol; col++) { - cols.add(new TerminateAfterCollector(new TotalHitCountCollector(), 0)); + cols.add(new TerminateAfterCollector(new DummyTotalHitCountCollector(), 0)); } Collections.shuffle(cols, random()); Collector multiCollector = MultiCollector.wrap(cols); diff --git a/lucene/core/src/test/org/apache/lucene/search/TestSearchWithThreads.java b/lucene/core/src/test/org/apache/lucene/search/TestSearchWithThreads.java index 17bcce0ee15..d4bd95fca7a 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestSearchWithThreads.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestSearchWithThreads.java @@ -24,6 +24,7 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.Term; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.tests.search.DummyTotalHitCountCollector; import org.apache.lucene.tests.util.LuceneTestCase; public class TestSearchWithThreads extends LuceneTestCase { @@ -57,7 +58,7 @@ public class TestSearchWithThreads extends LuceneTestCase { final AtomicBoolean failed = new AtomicBoolean(); final AtomicLong netSearch = new AtomicLong(); - TotalHitCountCollectorManager collectorManager = new TotalHitCountCollectorManager(); + CollectorManager collectorManager = DummyTotalHitCountCollector.createManager(); Thread[] threads = new Thread[numThreads]; for (int threadID = 0; threadID < numThreads; threadID++) { threads[threadID] = diff --git a/lucene/core/src/test/org/apache/lucene/search/TestTermQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestTermQuery.java index 5ad4392355b..9a6f39de726 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestTermQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestTermQuery.java @@ -34,6 +34,7 @@ import org.apache.lucene.index.Terms; import org.apache.lucene.index.TermsEnum; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.tests.search.DummyTotalHitCountCollector; import org.apache.lucene.tests.search.QueryUtils; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.TestUtil; @@ -91,14 +92,13 @@ public class TestTermQuery extends LuceneTestCase { IndexSearcher searcher = new IndexSearcher(reader); // use a collector rather than searcher.count() which would just read the // doc freq instead of creating a scorer - TotalHitCountCollectorManager collectorManager = new TotalHitCountCollectorManager(); - int totalHits = searcher.search(query, collectorManager); + int totalHits = searcher.search(query, DummyTotalHitCountCollector.createManager()); assertEquals(1, totalHits); TermQuery queryWithContext = new TermQuery( new Term("foo", "bar"), TermStates.build(reader.getContext(), new Term("foo", "bar"), true)); - totalHits = searcher.search(queryWithContext, collectorManager); + totalHits = searcher.search(queryWithContext, DummyTotalHitCountCollector.createManager()); assertEquals(1, totalHits); IOUtils.close(reader, w, dir); diff --git a/lucene/core/src/test/org/apache/lucene/search/TestTotalHitCountCollector.java b/lucene/core/src/test/org/apache/lucene/search/TestTotalHitCountCollector.java index 49049ebd378..eb2afb58e34 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestTotalHitCountCollector.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestTotalHitCountCollector.java @@ -20,6 +20,8 @@ import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.document.StringField; import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.BooleanClause.Occur; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.util.LuceneTestCase; @@ -42,6 +44,15 @@ public class TestTotalHitCountCollector extends LuceneTestCase { TotalHitCountCollectorManager collectorManager = new TotalHitCountCollectorManager(); int totalHits = searcher.search(new MatchAllDocsQuery(), collectorManager); assertEquals(5, totalHits); + + Query query = + new BooleanQuery.Builder() + .add(new TermQuery(new Term("string", "a1")), Occur.SHOULD) + .add(new TermQuery(new Term("string", "b3")), Occur.SHOULD) + .build(); + totalHits = searcher.search(query, collectorManager); + assertEquals(2, totalHits); + reader.close(); indexStore.close(); } diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/ProfilerCollector.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/ProfilerCollector.java index b38f7d626f1..94f56b93189 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/ProfilerCollector.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/ProfilerCollector.java @@ -24,6 +24,7 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Collector; import org.apache.lucene.search.LeafCollector; import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Weight; /** * This class wraps a Collector and times the execution of: - setScorer() - collect() - @@ -83,6 +84,11 @@ public class ProfilerCollector implements Collector { return collector.getLeafCollector(context); } + @Override + public void setWeight(Weight weight) { + collector.setWeight(weight); + } + @Override public ScoreMode scoreMode() { return collector.scoreMode(); diff --git a/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestIndexSortSortedNumericDocValuesRangeQuery.java b/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestIndexSortSortedNumericDocValuesRangeQuery.java index c10110af812..4323a005c17 100644 --- a/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestIndexSortSortedNumericDocValuesRangeQuery.java +++ b/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestIndexSortSortedNumericDocValuesRangeQuery.java @@ -40,11 +40,11 @@ import org.apache.lucene.search.Sort; import org.apache.lucene.search.SortField; import org.apache.lucene.search.SortedNumericSortField; import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.TotalHitCountCollectorManager; import org.apache.lucene.search.Weight; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.tests.search.DummyTotalHitCountCollector; import org.apache.lucene.tests.search.QueryUtils; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.TestUtil; @@ -221,7 +221,8 @@ public class TestIndexSortSortedNumericDocValuesRangeQuery extends LuceneTestCas private static void assertNumberOfHits(IndexSearcher searcher, Query query, int numberOfHits) throws IOException { assertEquals( - numberOfHits, searcher.search(query, new TotalHitCountCollectorManager()).intValue()); + numberOfHits, + searcher.search(query, DummyTotalHitCountCollector.createManager()).intValue()); assertEquals(numberOfHits, searcher.count(query)); } diff --git a/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestMultiRangeQueries.java b/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestMultiRangeQueries.java index fa420b608a0..03b92418199 100644 --- a/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestMultiRangeQueries.java +++ b/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestMultiRangeQueries.java @@ -37,9 +37,9 @@ import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Sort; import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.TotalHitCountCollectorManager; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.tests.search.DummyTotalHitCountCollector; import org.apache.lucene.tests.search.QueryUtils; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.TestUtil; @@ -808,8 +808,8 @@ public class TestMultiRangeQueries extends LuceneTestCase { MultiRangeQuery multiRangeQuery = (MultiRangeQuery) builder1.build().rewrite(reader); BooleanQuery booleanQuery = builder2.build(); - int count = searcher.search(multiRangeQuery, new TotalHitCountCollectorManager()); - int booleanCount = searcher.search(booleanQuery, new TotalHitCountCollectorManager()); + int count = searcher.search(multiRangeQuery, DummyTotalHitCountCollector.createManager()); + int booleanCount = searcher.search(booleanQuery, DummyTotalHitCountCollector.createManager()); assertEquals(booleanCount, count); } IOUtils.close(reader, w, dir); diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingCollector.java b/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingCollector.java index 7ffa9350e2b..cf2c2732614 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingCollector.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingCollector.java @@ -22,10 +22,12 @@ import org.apache.lucene.search.Collector; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.FilterCollector; import org.apache.lucene.search.LeafCollector; +import org.apache.lucene.search.Weight; /** A collector that asserts that it is used correctly. */ class AssertingCollector extends FilterCollector { + private boolean weightSet = false; private int maxDoc = -1; private int previousLeafMaxDoc = 0; @@ -43,6 +45,7 @@ class AssertingCollector extends FilterCollector { @Override public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException { + assert weightSet : "Set the weight first"; assert context.docBase >= previousLeafMaxDoc; previousLeafMaxDoc = context.docBase + context.reader().maxDoc(); @@ -65,4 +68,12 @@ class AssertingCollector extends FilterCollector { } }; } + + @Override + public void setWeight(Weight weight) { + assert weightSet == false : "Weight set twice"; + weightSet = true; + assert weight != null; + in.setWeight(weight); + } } diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/search/DummyTotalHitCountCollector.java b/lucene/test-framework/src/java/org/apache/lucene/tests/search/DummyTotalHitCountCollector.java new file mode 100644 index 00000000000..fcb53b96f0a --- /dev/null +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/search/DummyTotalHitCountCollector.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.tests.search; + +import java.io.IOException; +import java.util.Collection; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; +import org.apache.lucene.search.LeafCollector; +import org.apache.lucene.search.Scorable; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.TotalHitCountCollector; +import org.apache.lucene.search.Weight; + +/** + * A dummy version of {@link TotalHitCountCollector} that doesn't shortcut using {@link + * Weight#count}. + */ +public class DummyTotalHitCountCollector implements Collector { + private int totalHits; + + /** Constructor */ + public DummyTotalHitCountCollector() {} + + /** Get the number of hits. */ + public int getTotalHits() { + return totalHits; + } + + @Override + public ScoreMode scoreMode() { + return ScoreMode.COMPLETE_NO_SCORES; + } + + @Override + public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException { + return new LeafCollector() { + + @Override + public void setScorer(Scorable scorer) throws IOException {} + + @Override + public void collect(int doc) throws IOException { + totalHits++; + } + }; + } + + /** Create a collector manager. */ + public static CollectorManager createManager() { + return new CollectorManager() { + + @Override + public DummyTotalHitCountCollector newCollector() throws IOException { + return new DummyTotalHitCountCollector(); + } + + @Override + public Integer reduce(Collection collectors) throws IOException { + int sum = 0; + for (DummyTotalHitCountCollector coll : collectors) { + sum += coll.totalHits; + } + return sum; + } + }; + } +}