diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 5b0b3521d1a..727c0fa5e9d 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -124,6 +124,10 @@ New Features * LUCENE-10250: Add support for arbitrary length hierarchical SSDV facets. (Marc D'mello) +* LUCENE-10395: Add support for TotalHitCountCollectorManager, a collector manager + based on TotalHitCountCollector that allows users to parallelize counting the + number of hits. (Luca Cavanna, Adrien Grand) + Improvements --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/search/TotalHitCountCollectorManager.java b/lucene/core/src/java/org/apache/lucene/search/TotalHitCountCollectorManager.java new file mode 100644 index 00000000000..664602a4e5e --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/TotalHitCountCollectorManager.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import java.io.IOException; +import java.util.Collection; + +/** + * Collector manager based on {@link TotalHitCountCollector} that allows users to parallelize + * counting the number of hits, expected to be used mostly wrapped in {@link MultiCollectorManager}. + * For cases when this is the only collector manager used, {@link IndexSearcher#count(Query)} should + * be called instead of {@link IndexSearcher#search(Query, CollectorManager)} as the former is + * faster whenever the count can be returned directly from the index statistics. + */ +public class TotalHitCountCollectorManager + implements CollectorManager { + @Override + public TotalHitCountCollector newCollector() throws IOException { + return new TotalHitCountCollector(); + } + + @Override + public Integer reduce(Collection collectors) throws IOException { + int totalHits = 0; + for (TotalHitCountCollector collector : collectors) { + totalHits += collector.getTotalHits(); + } + return totalHits; + } +} diff --git a/lucene/core/src/test/org/apache/lucene/search/TestLRUQueryCache.java b/lucene/core/src/test/org/apache/lucene/search/TestLRUQueryCache.java index c7223a4a6d6..e23683bfee0 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestLRUQueryCache.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestLRUQueryCache.java @@ -39,6 +39,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field.Store; import org.apache.lucene.document.LongPoint; @@ -167,19 +168,36 @@ public class TestLRUQueryCache extends LuceneTestCase { RandomPicks.randomFrom( random(), new String[] {"blue", "red", "yellow", "green"}); final Query q = new TermQuery(new Term("color", value)); - TotalHitCountCollector collector = new TotalHitCountCollector(); - searcher.search(q, collector); // will use the cache - final int totalHits1 = collector.getTotalHits(); - TotalHitCountCollector collector2 = new TotalHitCountCollector(); - searcher.search( - q, - new FilterCollector(collector2) { - @Override - public ScoreMode scoreMode() { - return ScoreMode.COMPLETE; // will not use the cache because of scores - } - }); - final long totalHits2 = collector2.getTotalHits(); + TotalHitCountCollectorManager collectorManager = + new TotalHitCountCollectorManager(); + // will use the cache + final int totalHits1 = searcher.search(q, collectorManager); + final long totalHits2 = + searcher.search( + q, + new CollectorManager() { + @Override + public FilterCollector newCollector() { + return new FilterCollector(new TotalHitCountCollector()) { + @Override + public ScoreMode scoreMode() { + // will not use the cache because of scores + return ScoreMode.COMPLETE; + } + }; + } + + @Override + public Integer reduce(Collection collectors) + throws IOException { + return collectorManager.reduce( + collectors.stream() + .map( + filterCollector -> + (TotalHitCountCollector) filterCollector.in) + .collect(Collectors.toList())); + } + }); assertEquals(totalHits2, totalHits1); } finally { mgr.release(searcher); @@ -945,7 +963,7 @@ public class TestLRUQueryCache extends LuceneTestCase { searcher.setQueryCache(queryCache); searcher.setQueryCachingPolicy(policy); - searcher.search(query.build(), new TotalHitCountCollector()); + searcher.search(query.build(), new TotalHitCountCollectorManager()); reader.close(); dir.close(); @@ -1174,7 +1192,7 @@ public class TestLRUQueryCache extends LuceneTestCase { try { // trigger an eviction - searcher.search(new MatchAllDocsQuery(), new TotalHitCountCollector()); + searcher.search(new MatchAllDocsQuery(), new TotalHitCountCollectorManager()); fail(); } catch ( @SuppressWarnings("unused") diff --git a/lucene/core/src/test/org/apache/lucene/search/TestSearchWithThreads.java b/lucene/core/src/test/org/apache/lucene/search/TestSearchWithThreads.java index 9bb14ca448c..17bcce0ee15 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestSearchWithThreads.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestSearchWithThreads.java @@ -57,12 +57,11 @@ public class TestSearchWithThreads extends LuceneTestCase { final AtomicBoolean failed = new AtomicBoolean(); final AtomicLong netSearch = new AtomicLong(); - + TotalHitCountCollectorManager collectorManager = new TotalHitCountCollectorManager(); Thread[] threads = new Thread[numThreads]; for (int threadID = 0; threadID < numThreads; threadID++) { threads[threadID] = new Thread() { - TotalHitCountCollector col = new TotalHitCountCollector(); @Override public void run() { @@ -70,10 +69,8 @@ public class TestSearchWithThreads extends LuceneTestCase { long totHits = 0; long totSearch = 0; for (; totSearch < numSearches & !failed.get(); totSearch++) { - s.search(new TermQuery(new Term("body", "aaa")), col); - totHits += col.getTotalHits(); - s.search(new TermQuery(new Term("body", "bbb")), col); - totHits += col.getTotalHits(); + totHits += s.search(new TermQuery(new Term("body", "aaa")), collectorManager); + totHits += s.search(new TermQuery(new Term("body", "bbb")), collectorManager); } assertTrue(totSearch > 0 && totHits > 0); netSearch.addAndGet(totSearch); diff --git a/lucene/core/src/test/org/apache/lucene/search/TestTermQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestTermQuery.java index d73675ec1b8..5ad4392355b 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestTermQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestTermQuery.java @@ -91,16 +91,15 @@ public class TestTermQuery extends LuceneTestCase { IndexSearcher searcher = new IndexSearcher(reader); // use a collector rather than searcher.count() which would just read the // doc freq instead of creating a scorer - TotalHitCountCollector collector = new TotalHitCountCollector(); - searcher.search(query, collector); - assertEquals(1, collector.getTotalHits()); + TotalHitCountCollectorManager collectorManager = new TotalHitCountCollectorManager(); + int totalHits = searcher.search(query, collectorManager); + assertEquals(1, totalHits); TermQuery queryWithContext = new TermQuery( new Term("foo", "bar"), TermStates.build(reader.getContext(), new Term("foo", "bar"), true)); - collector = new TotalHitCountCollector(); - searcher.search(queryWithContext, collector); - assertEquals(1, collector.getTotalHits()); + totalHits = searcher.search(queryWithContext, collectorManager); + assertEquals(1, totalHits); IOUtils.close(reader, w, dir); } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestTotalHitCountCollectorManager.java b/lucene/core/src/test/org/apache/lucene/search/TestTotalHitCountCollectorManager.java new file mode 100644 index 00000000000..0eb2652b996 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/search/TestTotalHitCountCollectorManager.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.StringField; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.tests.util.LuceneTestCase; + +public class TestTotalHitCountCollectorManager extends LuceneTestCase { + + public void testBasics() throws Exception { + Directory indexStore = newDirectory(); + RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore); + for (int i = 0; i < 5; i++) { + Document doc = new Document(); + doc.add(new StringField("string", "a" + i, Field.Store.NO)); + doc.add(new StringField("string", "b" + i, Field.Store.NO)); + writer.addDocument(doc); + } + IndexReader reader = writer.getReader(); + writer.close(); + + IndexSearcher searcher = newSearcher(reader, true, true, true); + TotalHitCountCollectorManager collectorManager = new TotalHitCountCollectorManager(); + int totalHits = searcher.search(new MatchAllDocsQuery(), collectorManager); + assertEquals(5, totalHits); + + reader.close(); + indexStore.close(); + } +} diff --git a/lucene/misc/src/test/org/apache/lucene/misc/search/TestDocValuesStatsCollector.java b/lucene/misc/src/test/org/apache/lucene/misc/search/TestDocValuesStatsCollector.java index 9f9b3e7d49a..b91acf0ad25 100644 --- a/lucene/misc/src/test/org/apache/lucene/misc/search/TestDocValuesStatsCollector.java +++ b/lucene/misc/src/test/org/apache/lucene/misc/search/TestDocValuesStatsCollector.java @@ -45,8 +45,6 @@ import org.apache.lucene.misc.search.DocValuesStats.SortedLongDocValuesStats; import org.apache.lucene.misc.search.DocValuesStats.SortedSetDocValuesStats; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.MatchAllDocsQuery; -import org.apache.lucene.search.MultiCollector; -import org.apache.lucene.search.TotalHitCountCollector; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.TestUtil; @@ -396,10 +394,8 @@ public class TestDocValuesStatsCollector extends LuceneTestCase { try (DirectoryReader reader = DirectoryReader.open(indexWriter)) { IndexSearcher searcher = new IndexSearcher(reader); SortedSetDocValuesStats stats = new SortedSetDocValuesStats(field); - TotalHitCountCollector totalHitCount = new TotalHitCountCollector(); - searcher.search( - new MatchAllDocsQuery(), - MultiCollector.wrap(totalHitCount, new DocValuesStatsCollector(stats))); + + searcher.search(new MatchAllDocsQuery(), new DocValuesStatsCollector(stats)); int expCount = (int) nonNull(docValues).count(); assertEquals(expCount, stats.count()); diff --git a/lucene/misc/src/test/org/apache/lucene/misc/search/TestMemoryAccountingBitsetCollector.java b/lucene/misc/src/test/org/apache/lucene/misc/search/TestMemoryAccountingBitsetCollector.java index 970e9bca3fa..2bdf312c7f6 100644 --- a/lucene/misc/src/test/org/apache/lucene/misc/search/TestMemoryAccountingBitsetCollector.java +++ b/lucene/misc/src/test/org/apache/lucene/misc/search/TestMemoryAccountingBitsetCollector.java @@ -24,8 +24,6 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.misc.CollectorMemoryTracker; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.MatchAllDocsQuery; -import org.apache.lucene.search.MultiCollector; -import org.apache.lucene.search.TotalHitCountCollector; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.util.LuceneTestCase; @@ -64,14 +62,12 @@ public class TestMemoryAccountingBitsetCollector extends LuceneTestCase { CollectorMemoryTracker tracker = new CollectorMemoryTracker("testMemoryTracker", perCollectorMemoryLimit); MemoryAccountingBitsetCollector bitSetCollector = new MemoryAccountingBitsetCollector(tracker); - TotalHitCountCollector hitCountCollector = new TotalHitCountCollector(); IndexSearcher searcher = new IndexSearcher(reader); expectThrows( IllegalStateException.class, () -> { - searcher.search( - new MatchAllDocsQuery(), MultiCollector.wrap(hitCountCollector, bitSetCollector)); + searcher.search(new MatchAllDocsQuery(), bitSetCollector); }); } }