LUCENE-10395: Introduce TotalHitCountCollectorManager (#622)

This commit is contained in:
Luca Cavanna 2022-01-31 14:45:35 +01:00 committed by GitHub
parent 933c54fe87
commit df12e2b195
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 141 additions and 38 deletions

View File

@ -124,6 +124,10 @@ New Features
* LUCENE-10250: Add support for arbitrary length hierarchical SSDV facets. (Marc D'mello) * LUCENE-10250: Add support for arbitrary length hierarchical SSDV facets. (Marc D'mello)
* LUCENE-10395: Add support for TotalHitCountCollectorManager, a collector manager
based on TotalHitCountCollector that allows users to parallelize counting the
number of hits. (Luca Cavanna, Adrien Grand)
Improvements Improvements
--------------------- ---------------------

View File

@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
import java.util.Collection;
/**
* Collector manager based on {@link TotalHitCountCollector} that allows users to parallelize
* counting the number of hits, expected to be used mostly wrapped in {@link MultiCollectorManager}.
* For cases when this is the only collector manager used, {@link IndexSearcher#count(Query)} should
* be called instead of {@link IndexSearcher#search(Query, CollectorManager)} as the former is
* faster whenever the count can be returned directly from the index statistics.
*/
public class TotalHitCountCollectorManager
implements CollectorManager<TotalHitCountCollector, Integer> {
@Override
public TotalHitCountCollector newCollector() throws IOException {
return new TotalHitCountCollector();
}
@Override
public Integer reduce(Collection<TotalHitCountCollector> collectors) throws IOException {
int totalHits = 0;
for (TotalHitCountCollector collector : collectors) {
totalHits += collector.getTotalHits();
}
return totalHits;
}
}

View File

@ -39,6 +39,7 @@ import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import org.apache.lucene.document.Document; import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field.Store; import org.apache.lucene.document.Field.Store;
import org.apache.lucene.document.LongPoint; import org.apache.lucene.document.LongPoint;
@ -167,19 +168,36 @@ public class TestLRUQueryCache extends LuceneTestCase {
RandomPicks.randomFrom( RandomPicks.randomFrom(
random(), new String[] {"blue", "red", "yellow", "green"}); random(), new String[] {"blue", "red", "yellow", "green"});
final Query q = new TermQuery(new Term("color", value)); final Query q = new TermQuery(new Term("color", value));
TotalHitCountCollector collector = new TotalHitCountCollector(); TotalHitCountCollectorManager collectorManager =
searcher.search(q, collector); // will use the cache new TotalHitCountCollectorManager();
final int totalHits1 = collector.getTotalHits(); // will use the cache
TotalHitCountCollector collector2 = new TotalHitCountCollector(); final int totalHits1 = searcher.search(q, collectorManager);
searcher.search( final long totalHits2 =
q, searcher.search(
new FilterCollector(collector2) { q,
@Override new CollectorManager<FilterCollector, Integer>() {
public ScoreMode scoreMode() { @Override
return ScoreMode.COMPLETE; // will not use the cache because of scores public FilterCollector newCollector() {
} return new FilterCollector(new TotalHitCountCollector()) {
}); @Override
final long totalHits2 = collector2.getTotalHits(); public ScoreMode scoreMode() {
// will not use the cache because of scores
return ScoreMode.COMPLETE;
}
};
}
@Override
public Integer reduce(Collection<FilterCollector> collectors)
throws IOException {
return collectorManager.reduce(
collectors.stream()
.map(
filterCollector ->
(TotalHitCountCollector) filterCollector.in)
.collect(Collectors.toList()));
}
});
assertEquals(totalHits2, totalHits1); assertEquals(totalHits2, totalHits1);
} finally { } finally {
mgr.release(searcher); mgr.release(searcher);
@ -945,7 +963,7 @@ public class TestLRUQueryCache extends LuceneTestCase {
searcher.setQueryCache(queryCache); searcher.setQueryCache(queryCache);
searcher.setQueryCachingPolicy(policy); searcher.setQueryCachingPolicy(policy);
searcher.search(query.build(), new TotalHitCountCollector()); searcher.search(query.build(), new TotalHitCountCollectorManager());
reader.close(); reader.close();
dir.close(); dir.close();
@ -1174,7 +1192,7 @@ public class TestLRUQueryCache extends LuceneTestCase {
try { try {
// trigger an eviction // trigger an eviction
searcher.search(new MatchAllDocsQuery(), new TotalHitCountCollector()); searcher.search(new MatchAllDocsQuery(), new TotalHitCountCollectorManager());
fail(); fail();
} catch ( } catch (
@SuppressWarnings("unused") @SuppressWarnings("unused")

View File

@ -57,12 +57,11 @@ public class TestSearchWithThreads extends LuceneTestCase {
final AtomicBoolean failed = new AtomicBoolean(); final AtomicBoolean failed = new AtomicBoolean();
final AtomicLong netSearch = new AtomicLong(); final AtomicLong netSearch = new AtomicLong();
TotalHitCountCollectorManager collectorManager = new TotalHitCountCollectorManager();
Thread[] threads = new Thread[numThreads]; Thread[] threads = new Thread[numThreads];
for (int threadID = 0; threadID < numThreads; threadID++) { for (int threadID = 0; threadID < numThreads; threadID++) {
threads[threadID] = threads[threadID] =
new Thread() { new Thread() {
TotalHitCountCollector col = new TotalHitCountCollector();
@Override @Override
public void run() { public void run() {
@ -70,10 +69,8 @@ public class TestSearchWithThreads extends LuceneTestCase {
long totHits = 0; long totHits = 0;
long totSearch = 0; long totSearch = 0;
for (; totSearch < numSearches & !failed.get(); totSearch++) { for (; totSearch < numSearches & !failed.get(); totSearch++) {
s.search(new TermQuery(new Term("body", "aaa")), col); totHits += s.search(new TermQuery(new Term("body", "aaa")), collectorManager);
totHits += col.getTotalHits(); totHits += s.search(new TermQuery(new Term("body", "bbb")), collectorManager);
s.search(new TermQuery(new Term("body", "bbb")), col);
totHits += col.getTotalHits();
} }
assertTrue(totSearch > 0 && totHits > 0); assertTrue(totSearch > 0 && totHits > 0);
netSearch.addAndGet(totSearch); netSearch.addAndGet(totSearch);

View File

@ -91,16 +91,15 @@ public class TestTermQuery extends LuceneTestCase {
IndexSearcher searcher = new IndexSearcher(reader); IndexSearcher searcher = new IndexSearcher(reader);
// use a collector rather than searcher.count() which would just read the // use a collector rather than searcher.count() which would just read the
// doc freq instead of creating a scorer // doc freq instead of creating a scorer
TotalHitCountCollector collector = new TotalHitCountCollector(); TotalHitCountCollectorManager collectorManager = new TotalHitCountCollectorManager();
searcher.search(query, collector); int totalHits = searcher.search(query, collectorManager);
assertEquals(1, collector.getTotalHits()); assertEquals(1, totalHits);
TermQuery queryWithContext = TermQuery queryWithContext =
new TermQuery( new TermQuery(
new Term("foo", "bar"), new Term("foo", "bar"),
TermStates.build(reader.getContext(), new Term("foo", "bar"), true)); TermStates.build(reader.getContext(), new Term("foo", "bar"), true));
collector = new TotalHitCountCollector(); totalHits = searcher.search(queryWithContext, collectorManager);
searcher.search(queryWithContext, collector); assertEquals(1, totalHits);
assertEquals(1, collector.getTotalHits());
IOUtils.close(reader, w, dir); IOUtils.close(reader, w, dir);
} }

View File

@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.LuceneTestCase;
public class TestTotalHitCountCollectorManager extends LuceneTestCase {
public void testBasics() throws Exception {
Directory indexStore = newDirectory();
RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore);
for (int i = 0; i < 5; i++) {
Document doc = new Document();
doc.add(new StringField("string", "a" + i, Field.Store.NO));
doc.add(new StringField("string", "b" + i, Field.Store.NO));
writer.addDocument(doc);
}
IndexReader reader = writer.getReader();
writer.close();
IndexSearcher searcher = newSearcher(reader, true, true, true);
TotalHitCountCollectorManager collectorManager = new TotalHitCountCollectorManager();
int totalHits = searcher.search(new MatchAllDocsQuery(), collectorManager);
assertEquals(5, totalHits);
reader.close();
indexStore.close();
}
}

View File

@ -45,8 +45,6 @@ import org.apache.lucene.misc.search.DocValuesStats.SortedLongDocValuesStats;
import org.apache.lucene.misc.search.DocValuesStats.SortedSetDocValuesStats; import org.apache.lucene.misc.search.DocValuesStats.SortedSetDocValuesStats;
import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.MultiCollector;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.store.Directory; import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil; import org.apache.lucene.tests.util.TestUtil;
@ -396,10 +394,8 @@ public class TestDocValuesStatsCollector extends LuceneTestCase {
try (DirectoryReader reader = DirectoryReader.open(indexWriter)) { try (DirectoryReader reader = DirectoryReader.open(indexWriter)) {
IndexSearcher searcher = new IndexSearcher(reader); IndexSearcher searcher = new IndexSearcher(reader);
SortedSetDocValuesStats stats = new SortedSetDocValuesStats(field); SortedSetDocValuesStats stats = new SortedSetDocValuesStats(field);
TotalHitCountCollector totalHitCount = new TotalHitCountCollector();
searcher.search( searcher.search(new MatchAllDocsQuery(), new DocValuesStatsCollector(stats));
new MatchAllDocsQuery(),
MultiCollector.wrap(totalHitCount, new DocValuesStatsCollector(stats)));
int expCount = (int) nonNull(docValues).count(); int expCount = (int) nonNull(docValues).count();
assertEquals(expCount, stats.count()); assertEquals(expCount, stats.count());

View File

@ -24,8 +24,6 @@ import org.apache.lucene.index.IndexReader;
import org.apache.lucene.misc.CollectorMemoryTracker; import org.apache.lucene.misc.CollectorMemoryTracker;
import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.MultiCollector;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.store.Directory; import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.LuceneTestCase;
@ -64,14 +62,12 @@ public class TestMemoryAccountingBitsetCollector extends LuceneTestCase {
CollectorMemoryTracker tracker = CollectorMemoryTracker tracker =
new CollectorMemoryTracker("testMemoryTracker", perCollectorMemoryLimit); new CollectorMemoryTracker("testMemoryTracker", perCollectorMemoryLimit);
MemoryAccountingBitsetCollector bitSetCollector = new MemoryAccountingBitsetCollector(tracker); MemoryAccountingBitsetCollector bitSetCollector = new MemoryAccountingBitsetCollector(tracker);
TotalHitCountCollector hitCountCollector = new TotalHitCountCollector();
IndexSearcher searcher = new IndexSearcher(reader); IndexSearcher searcher = new IndexSearcher(reader);
expectThrows( expectThrows(
IllegalStateException.class, IllegalStateException.class,
() -> { () -> {
searcher.search( searcher.search(new MatchAllDocsQuery(), bitSetCollector);
new MatchAllDocsQuery(), MultiCollector.wrap(hitCountCollector, bitSetCollector));
}); });
} }
} }