LUCENE-8992: Share minimum score across segment in concurrent search

This is a follow up of LUCENE-8978 that introduces shared minimum score across segment
in concurrent search for top field collectors that sort by relevance first.
This commit is contained in:
jimczi 2019-09-27 16:06:23 +02:00
parent a9cf5f6abe
commit 58fabbed2b
3 changed files with 123 additions and 71 deletions

View File

@ -598,15 +598,16 @@ public class IndexSearcher {
final int cappedNumHits = Math.min(numHits, limit); final int cappedNumHits = Math.min(numHits, limit);
final Sort rewrittenSort = sort.rewrite(this); final Sort rewrittenSort = sort.rewrite(this);
final CollectorManager<TopFieldCollector, TopFieldDocs> manager = new CollectorManager<TopFieldCollector, TopFieldDocs>() { final CollectorManager<TopFieldCollector, TopFieldDocs> manager = new CollectorManager<>() {
private final HitsThresholdChecker hitsThresholdChecker = (executor == null || leafSlices.length <= 1) ? HitsThresholdChecker.create(TOTAL_HITS_THRESHOLD) : private final HitsThresholdChecker hitsThresholdChecker = (executor == null || leafSlices.length <= 1) ? HitsThresholdChecker.create(TOTAL_HITS_THRESHOLD) :
HitsThresholdChecker.createShared(TOTAL_HITS_THRESHOLD); HitsThresholdChecker.createShared(TOTAL_HITS_THRESHOLD);
private final BottomValueChecker bottomValueChecker = (executor ==null || leafSlices.length <= 1) ? BottomValueChecker.createMaxBottomScoreChecker() : null;
@Override @Override
public TopFieldCollector newCollector() throws IOException { public TopFieldCollector newCollector() throws IOException {
// TODO: don't pay the price for accurate hit counts by default // TODO: don't pay the price for accurate hit counts by default
return TopFieldCollector.create(rewrittenSort, cappedNumHits, after, hitsThresholdChecker); return TopFieldCollector.create(rewrittenSort, cappedNumHits, after, hitsThresholdChecker, bottomValueChecker);
} }
@Override @Override

View File

@ -101,8 +101,9 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
final FieldValueHitQueue<Entry> queue; final FieldValueHitQueue<Entry> queue;
public SimpleFieldCollector(Sort sort, FieldValueHitQueue<Entry> queue, int numHits, public SimpleFieldCollector(Sort sort, FieldValueHitQueue<Entry> queue, int numHits,
HitsThresholdChecker hitsThresholdChecker) { HitsThresholdChecker hitsThresholdChecker,
super(queue, numHits, hitsThresholdChecker, sort.needsScores()); BottomValueChecker bottomValueChecker) {
super(queue, numHits, hitsThresholdChecker, sort.needsScores(), bottomValueChecker);
this.sort = sort; this.sort = sort;
this.queue = queue; this.queue = queue;
} }
@ -185,8 +186,8 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
final FieldDoc after; final FieldDoc after;
public PagingFieldCollector(Sort sort, FieldValueHitQueue<Entry> queue, FieldDoc after, int numHits, public PagingFieldCollector(Sort sort, FieldValueHitQueue<Entry> queue, FieldDoc after, int numHits,
HitsThresholdChecker hitsThresholdChecker) { HitsThresholdChecker hitsThresholdChecker, BottomValueChecker bottomValueChecker) {
super(queue, numHits, hitsThresholdChecker, sort.needsScores()); super(queue, numHits, hitsThresholdChecker, sort.needsScores(), bottomValueChecker);
this.sort = sort; this.sort = sort;
this.queue = queue; this.queue = queue;
this.after = after; this.after = after;
@ -237,7 +238,9 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
} else { } else {
collectedAllCompetitiveHits = true; collectedAllCompetitiveHits = true;
} }
} else if (totalHitsRelation == Relation.GREATER_THAN_OR_EQUAL_TO) { } else if (totalHitsRelation == Relation.EQUAL_TO) {
// we just reached totalHitsThreshold, we can start setting the min
// competitive score now
updateMinCompetitiveScore(scorer); updateMinCompetitiveScore(scorer);
} }
return; return;
@ -284,6 +287,7 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
final int numHits; final int numHits;
final HitsThresholdChecker hitsThresholdChecker; final HitsThresholdChecker hitsThresholdChecker;
final BottomValueChecker bottomValueChecker;
final FieldComparator.RelevanceComparator firstComparator; final FieldComparator.RelevanceComparator firstComparator;
final boolean canSetMinScore; final boolean canSetMinScore;
final int numComparators; final int numComparators;
@ -299,7 +303,8 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
// visibility, then anyone will be able to extend the class, which is not what // visibility, then anyone will be able to extend the class, which is not what
// we want. // we want.
private TopFieldCollector(FieldValueHitQueue<Entry> pq, int numHits, private TopFieldCollector(FieldValueHitQueue<Entry> pq, int numHits,
HitsThresholdChecker hitsThresholdChecker, boolean needsScores) { HitsThresholdChecker hitsThresholdChecker, boolean needsScores,
BottomValueChecker bottomValueChecker) {
super(pq); super(pq);
this.needsScores = needsScores; this.needsScores = needsScores;
this.numHits = numHits; this.numHits = numHits;
@ -318,6 +323,7 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
scoreMode = needsScores ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES; scoreMode = needsScores ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES;
canSetMinScore = false; canSetMinScore = false;
} }
this.bottomValueChecker = bottomValueChecker;
} }
@Override @Override
@ -326,10 +332,21 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
} }
protected void updateMinCompetitiveScore(Scorable scorer) throws IOException { protected void updateMinCompetitiveScore(Scorable scorer) throws IOException {
if (canSetMinScore && hitsThresholdChecker.isThresholdReached() && queueFull) { if (canSetMinScore && hitsThresholdChecker.isThresholdReached()
assert bottom != null && firstComparator != null; && (queueFull || (bottomValueChecker != null && bottomValueChecker.getBottomValue() > 0f))) {
float minScore = firstComparator.value(bottom.slot); float maxMinScore = Float.NEGATIVE_INFINITY;
scorer.setMinCompetitiveScore(minScore); if (queueFull) {
assert bottom != null && firstComparator != null;
maxMinScore = firstComparator.value(bottom.slot);
if (bottomValueChecker != null) {
bottomValueChecker.updateThreadLocalBottomValue(maxMinScore);
}
}
if (bottomValueChecker != null) {
maxMinScore = Math.max(maxMinScore, bottomValueChecker.getBottomValue());
}
assert maxMinScore > 0f;
scorer.setMinCompetitiveScore(maxMinScore);
totalHitsRelation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO; totalHitsRelation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
} }
} }
@ -389,14 +406,14 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
throw new IllegalArgumentException("totalHitsThreshold must be >= 0, got " + totalHitsThreshold); throw new IllegalArgumentException("totalHitsThreshold must be >= 0, got " + totalHitsThreshold);
} }
return create(sort, numHits, after, HitsThresholdChecker.create(totalHitsThreshold)); return create(sort, numHits, after, HitsThresholdChecker.create(totalHitsThreshold), null);
} }
/** /**
* Same as above with an additional parameter to allow passing in the threshold checker * Same as above with additional parameters to allow passing in the threshold checker and the bottom value checker.
*/ */
static TopFieldCollector create(Sort sort, int numHits, FieldDoc after, static TopFieldCollector create(Sort sort, int numHits, FieldDoc after,
HitsThresholdChecker hitsThresholdChecker) { HitsThresholdChecker hitsThresholdChecker, BottomValueChecker bottomValueChecker) {
if (sort.fields.length == 0) { if (sort.fields.length == 0) {
throw new IllegalArgumentException("Sort must contain at least one field"); throw new IllegalArgumentException("Sort must contain at least one field");
@ -413,7 +430,7 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
FieldValueHitQueue<Entry> queue = FieldValueHitQueue.create(sort.fields, numHits); FieldValueHitQueue<Entry> queue = FieldValueHitQueue.create(sort.fields, numHits);
if (after == null) { if (after == null) {
return new SimpleFieldCollector(sort, queue, numHits, hitsThresholdChecker); return new SimpleFieldCollector(sort, queue, numHits, hitsThresholdChecker, bottomValueChecker);
} else { } else {
if (after.fields == null) { if (after.fields == null) {
throw new IllegalArgumentException("after.fields wasn't set; you must pass fillFields=true for the previous search"); throw new IllegalArgumentException("after.fields wasn't set; you must pass fillFields=true for the previous search");
@ -423,22 +440,24 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
throw new IllegalArgumentException("after.fields has " + after.fields.length + " values but sort has " + sort.getSort().length); throw new IllegalArgumentException("after.fields has " + after.fields.length + " values but sort has " + sort.getSort().length);
} }
return new PagingFieldCollector(sort, queue, after, numHits, hitsThresholdChecker); return new PagingFieldCollector(sort, queue, after, numHits, hitsThresholdChecker, bottomValueChecker);
} }
} }
/** /**
* Create a CollectorManager which uses a shared hit counter to maintain number of hits * Create a CollectorManager which uses a shared hit counter to maintain number of hits
* and a shared bottom value checker to propagate the minimum score accross segments if
* the primary sort is by relevancy.
*/ */
public static CollectorManager<TopFieldCollector, TopFieldDocs> createSharedManager(Sort sort, int numHits, FieldDoc after, public static CollectorManager<TopFieldCollector, TopFieldDocs> createSharedManager(Sort sort, int numHits, FieldDoc after, int totalHitsThreshold) {
int totalHitsThreshold) {
return new CollectorManager<>() { return new CollectorManager<>() {
private final HitsThresholdChecker hitsThresholdChecker = HitsThresholdChecker.createShared(totalHitsThreshold); private final HitsThresholdChecker hitsThresholdChecker = HitsThresholdChecker.createShared(totalHitsThreshold);
private final BottomValueChecker bottomValueChecker = BottomValueChecker.createMaxBottomScoreChecker();
@Override @Override
public TopFieldCollector newCollector() throws IOException { public TopFieldCollector newCollector() throws IOException {
return create(sort, numHits, after, hitsThresholdChecker); return create(sort, numHits, after, hitsThresholdChecker, bottomValueChecker);
} }
@Override @Override

View File

@ -25,19 +25,16 @@ import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import org.apache.lucene.document.Document; import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.MultiTerms;
import org.apache.lucene.index.NoMergePolicy; import org.apache.lucene.index.NoMergePolicy;
import org.apache.lucene.index.RandomIndexWriter; import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.Term; import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.store.Directory; import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.LineFileDocs;
import org.apache.lucene.util.LuceneTestCase; import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.NamedThreadFactory; import org.apache.lucene.util.NamedThreadFactory;
@ -112,30 +109,53 @@ public class TestTopDocsCollector extends LuceneTestCase {
return tdc; return tdc;
} }
private TopDocsCollector<ScoreDoc> doSearchWithThreshold(int numResults, int thresHold) throws IOException { private TopDocsCollector<ScoreDoc> doSearchWithThreshold(int numResults, int thresHold, Query q, IndexReader indexReader) throws IOException {
Query q = new MatchAllDocsQuery(); IndexSearcher searcher = new IndexSearcher(indexReader);
IndexSearcher searcher = newSearcher(reader);
TopDocsCollector<ScoreDoc> tdc = TopScoreDocCollector.create(numResults, thresHold); TopDocsCollector<ScoreDoc> tdc = TopScoreDocCollector.create(numResults, thresHold);
searcher.search(q, tdc); searcher.search(q, tdc);
return tdc; return tdc;
} }
private TopDocs doConcurrentSearchWithThreshold(int numResults, int threshold, IndexReader reader) throws IOException { private TopDocs doConcurrentSearchWithThreshold(int numResults, int threshold, Query q, IndexReader indexReader) throws IOException {
Query q = new MatchAllDocsQuery();
ExecutorService service = new ThreadPoolExecutor(4, 4, 0L, TimeUnit.MILLISECONDS, ExecutorService service = new ThreadPoolExecutor(4, 4, 0L, TimeUnit.MILLISECONDS,
new LinkedBlockingQueue<Runnable>(), new LinkedBlockingQueue<Runnable>(),
new NamedThreadFactory("TestTopDocsCollector")); new NamedThreadFactory("TestTopDocsCollector"));
IndexSearcher searcher = new IndexSearcher(reader, service); try {
IndexSearcher searcher = new IndexSearcher(indexReader, service);
CollectorManager collectorManager = TopScoreDocCollector.createSharedManager(numResults, CollectorManager collectorManager = TopScoreDocCollector.createSharedManager(numResults,
null, threshold); null, threshold);
TopDocs tdc = (TopDocs) searcher.search(q, collectorManager); return (TopDocs) searcher.search(q, collectorManager);
} finally {
service.shutdown(); service.shutdown();
}
}
private TopFieldCollector doSearchWithThreshold(int numResults, int thresHold, Query q, Sort sort, IndexReader indexReader) throws IOException {
IndexSearcher searcher = new IndexSearcher(indexReader);
TopFieldCollector tdc = TopFieldCollector.create(sort, numResults, thresHold);
searcher.search(q, tdc);
return tdc; return tdc;
} }
private TopDocs doConcurrentSearchWithThreshold(int numResults, int threshold, Query q, Sort sort, IndexReader indexReader) throws IOException {
ExecutorService service = new ThreadPoolExecutor(4, 4, 0L, TimeUnit.MILLISECONDS,
new LinkedBlockingQueue<Runnable>(),
new NamedThreadFactory("TestTopDocsCollector"));
try {
IndexSearcher searcher = new IndexSearcher(indexReader, service);
CollectorManager collectorManager = TopFieldCollector.createSharedManager(sort, numResults,
null, threshold);
TopDocs tdc = (TopDocs) searcher.search(q, collectorManager);
return tdc;
} finally {
service.shutdown();
}
}
@Override @Override
public void setUp() throws Exception { public void setUp() throws Exception {
@ -344,8 +364,8 @@ public class TestTopDocsCollector extends LuceneTestCase {
assertEquals(2, reader.leaves().size()); assertEquals(2, reader.leaves().size());
w.close(); w.close();
TopDocsCollector collector = doSearchWithThreshold(5, 10); TopDocsCollector collector = doSearchWithThreshold( 5, 10, q, reader);
TopDocs tdc = doConcurrentSearchWithThreshold(5, 10, reader); TopDocs tdc = doConcurrentSearchWithThreshold(5, 10, q, reader);
TopDocs tdc2 = collector.topDocs(); TopDocs tdc2 = collector.topDocs();
CheckHits.checkEqual(q, tdc.scoreDocs, tdc2.scoreDocs); CheckHits.checkEqual(q, tdc.scoreDocs, tdc2.scoreDocs);
@ -404,43 +424,55 @@ public class TestTopDocsCollector extends LuceneTestCase {
public void testGlobalScore() throws Exception { public void testGlobalScore() throws Exception {
Directory dir = newDirectory(); Directory dir = newDirectory();
RandomIndexWriter writer = new RandomIndexWriter(random(), dir); RandomIndexWriter w = new RandomIndexWriter(random(), dir, newIndexWriterConfig());
try (LineFileDocs docs = new LineFileDocs(random())) { int numDocs = atLeast(1000);
int numDocs = atLeast(100); for (int i = 0; i < numDocs; ++i) {
for (int i = 0; i < numDocs; i++) { int numAs = 1 + random().nextInt(5);
writer.addDocument(docs.nextDoc()); int numBs = random().nextFloat() < 0.5f ? 0 : 1 + random().nextInt(5);
int numCs = random().nextFloat() < 0.1f ? 0 : 1 + random().nextInt(5);
Document doc = new Document();
for (int j = 0; j < numAs; ++j) {
doc.add(new StringField("f", "A", Field.Store.NO));
} }
} for (int j = 0; j < numBs; ++j) {
doc.add(new StringField("f", "B", Field.Store.NO));
IndexReader reader = writer.getReader();
writer.close();
final IndexSearcher s = newSearcher(reader);
Terms terms = MultiTerms.getTerms(reader, "body");
int termCount = 0;
TermsEnum termsEnum = terms.iterator();
while(termsEnum.next() != null) {
termCount++;
}
assertTrue(termCount > 0);
// Target ~10 terms to search:
double chance = 10.0 / termCount;
termsEnum = terms.iterator();
while(termsEnum.next() != null) {
if (random().nextDouble() <= chance) {
BytesRef term = BytesRef.deepCopyOf(termsEnum.term());
Query query = new TermQuery(new Term("body", term));
TopDocsCollector collector = doSearchWithThreshold(5, 10);
TopDocs tdc = doConcurrentSearchWithThreshold(5, 10, reader);
TopDocs tdc2 = collector.topDocs();
CheckHits.checkEqual(query, tdc.scoreDocs, tdc2.scoreDocs);
} }
for (int j = 0; j < numCs; ++j) {
doc.add(new StringField("f", "C", Field.Store.NO));
}
w.addDocument(doc);
}
IndexReader indexReader = w.getReader();
w.close();
Query[] queries = new Query[]{
new TermQuery(new Term("f", "A")),
new TermQuery(new Term("f", "B")),
new TermQuery(new Term("f", "C")),
new BooleanQuery.Builder()
.add(new TermQuery(new Term("f", "A")), BooleanClause.Occur.MUST)
.add(new TermQuery(new Term("f", "B")), BooleanClause.Occur.SHOULD)
.build()
};
for (Query query : queries) {
TopDocsCollector collector = doSearchWithThreshold(5, 10, query, indexReader);
TopDocs tdc = doConcurrentSearchWithThreshold(5, 10, query, indexReader);
TopDocs tdc2 = collector.topDocs();
assertTrue(tdc.totalHits.value > 0);
assertTrue(tdc2.totalHits.value > 0);
CheckHits.checkEqual(query, tdc.scoreDocs, tdc2.scoreDocs);
Sort sort = new Sort(new SortField[]{SortField.FIELD_SCORE, SortField.FIELD_DOC});
TopDocsCollector fieldCollector = doSearchWithThreshold(5, 10, query, sort, indexReader);
tdc = doConcurrentSearchWithThreshold(5, 10, query, sort, indexReader);
tdc2 = fieldCollector.topDocs();
assertTrue(tdc.totalHits.value > 0);
assertTrue(tdc2.totalHits.value > 0);
CheckHits.checkEqual(query, tdc.scoreDocs, tdc2.scoreDocs);
} }
reader.close(); indexReader.close();
dir.close(); dir.close();
} }