From 58fabbed2bf3745e06f490421ac67de65e99de96 Mon Sep 17 00:00:00 2001 From: jimczi Date: Fri, 27 Sep 2019 16:06:23 +0200 Subject: [PATCH] LUCENE-8992: Share minimum score across segment in concurrent search This is a follow up of LUCENE-8978 that introduces shared minimum score across segment in concurrent search for top field collectors that sort by relevance first. --- .../apache/lucene/search/IndexSearcher.java | 5 +- .../lucene/search/TopFieldCollector.java | 55 ++++--- .../lucene/search/TestTopDocsCollector.java | 134 +++++++++++------- 3 files changed, 123 insertions(+), 71 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java b/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java index e0f0cdf4214..0078c62392c 100644 --- a/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java @@ -598,15 +598,16 @@ public class IndexSearcher { final int cappedNumHits = Math.min(numHits, limit); final Sort rewrittenSort = sort.rewrite(this); - final CollectorManager manager = new CollectorManager() { + final CollectorManager manager = new CollectorManager<>() { private final HitsThresholdChecker hitsThresholdChecker = (executor == null || leafSlices.length <= 1) ? HitsThresholdChecker.create(TOTAL_HITS_THRESHOLD) : HitsThresholdChecker.createShared(TOTAL_HITS_THRESHOLD); + private final BottomValueChecker bottomValueChecker = (executor == null || leafSlices.length <= 1) ? 
null : BottomValueChecker.createMaxBottomScoreChecker(); /* share the bottom score only when several slices are searched concurrently; a single collector has nothing to share it with */ @Override public TopFieldCollector newCollector() throws IOException { // TODO: don't pay the price for accurate hit counts by default - return TopFieldCollector.create(rewrittenSort, cappedNumHits, after, hitsThresholdChecker); + return TopFieldCollector.create(rewrittenSort, cappedNumHits, after, hitsThresholdChecker, bottomValueChecker); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/search/TopFieldCollector.java b/lucene/core/src/java/org/apache/lucene/search/TopFieldCollector.java index bf1c929699b..d49ecf8eac4 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TopFieldCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/TopFieldCollector.java @@ -101,8 +101,9 @@ public abstract class TopFieldCollector extends TopDocsCollector { final FieldValueHitQueue queue; public SimpleFieldCollector(Sort sort, FieldValueHitQueue queue, int numHits, - HitsThresholdChecker hitsThresholdChecker) { - super(queue, numHits, hitsThresholdChecker, sort.needsScores()); + HitsThresholdChecker hitsThresholdChecker, + BottomValueChecker bottomValueChecker) { + super(queue, numHits, hitsThresholdChecker, sort.needsScores(), bottomValueChecker); this.sort = sort; this.queue = queue; } @@ -185,8 +186,8 @@ public abstract class TopFieldCollector extends TopDocsCollector { final FieldDoc after; public PagingFieldCollector(Sort sort, FieldValueHitQueue queue, FieldDoc after, int numHits, - HitsThresholdChecker hitsThresholdChecker) { - super(queue, numHits, hitsThresholdChecker, sort.needsScores()); + HitsThresholdChecker hitsThresholdChecker, BottomValueChecker bottomValueChecker) { + super(queue, numHits, hitsThresholdChecker, sort.needsScores(), bottomValueChecker); this.sort = sort; this.queue = queue; this.after = after; @@ -237,7 +238,9 @@ public abstract class TopFieldCollector extends TopDocsCollector { } else { collectedAllCompetitiveHits = true; } - } else if (totalHitsRelation == 
Relation.GREATER_THAN_OR_EQUAL_TO) { + } else if (totalHitsRelation == Relation.EQUAL_TO) { + // we just reached totalHitsThreshold, we can start setting the min + // competitive score now updateMinCompetitiveScore(scorer); } return; @@ -284,6 +287,7 @@ public abstract class TopFieldCollector extends TopDocsCollector { final int numHits; final HitsThresholdChecker hitsThresholdChecker; + final BottomValueChecker bottomValueChecker; final FieldComparator.RelevanceComparator firstComparator; final boolean canSetMinScore; final int numComparators; @@ -299,7 +303,8 @@ public abstract class TopFieldCollector extends TopDocsCollector { // visibility, then anyone will be able to extend the class, which is not what // we want. private TopFieldCollector(FieldValueHitQueue pq, int numHits, - HitsThresholdChecker hitsThresholdChecker, boolean needsScores) { + HitsThresholdChecker hitsThresholdChecker, boolean needsScores, + BottomValueChecker bottomValueChecker) { super(pq); this.needsScores = needsScores; this.numHits = numHits; @@ -318,6 +323,7 @@ public abstract class TopFieldCollector extends TopDocsCollector { scoreMode = needsScores ? 
ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES; canSetMinScore = false; } + this.bottomValueChecker = bottomValueChecker; } @Override @@ -326,10 +332,21 @@ public abstract class TopFieldCollector extends TopDocsCollector { } protected void updateMinCompetitiveScore(Scorable scorer) throws IOException { - if (canSetMinScore && hitsThresholdChecker.isThresholdReached() && queueFull) { - assert bottom != null && firstComparator != null; - float minScore = firstComparator.value(bottom.slot); - scorer.setMinCompetitiveScore(minScore); + if (canSetMinScore && hitsThresholdChecker.isThresholdReached() + && (queueFull || (bottomValueChecker != null && bottomValueChecker.getBottomValue() > 0f))) { + float maxMinScore = Float.NEGATIVE_INFINITY; + if (queueFull) { + assert bottom != null && firstComparator != null; + maxMinScore = firstComparator.value(bottom.slot); + if (bottomValueChecker != null) { + bottomValueChecker.updateThreadLocalBottomValue(maxMinScore); + } + } + if (bottomValueChecker != null) { + maxMinScore = Math.max(maxMinScore, bottomValueChecker.getBottomValue()); + } + assert maxMinScore > 0f; + scorer.setMinCompetitiveScore(maxMinScore); totalHitsRelation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO; } } @@ -389,14 +406,14 @@ public abstract class TopFieldCollector extends TopDocsCollector { throw new IllegalArgumentException("totalHitsThreshold must be >= 0, got " + totalHitsThreshold); } - return create(sort, numHits, after, HitsThresholdChecker.create(totalHitsThreshold)); + return create(sort, numHits, after, HitsThresholdChecker.create(totalHitsThreshold), null); } /** - * Same as above with an additional parameter to allow passing in the threshold checker + * Same as above with additional parameters to allow passing in the threshold checker and the bottom value checker. 
*/ static TopFieldCollector create(Sort sort, int numHits, FieldDoc after, - HitsThresholdChecker hitsThresholdChecker) { + HitsThresholdChecker hitsThresholdChecker, BottomValueChecker bottomValueChecker) { if (sort.fields.length == 0) { throw new IllegalArgumentException("Sort must contain at least one field"); @@ -413,7 +430,7 @@ public abstract class TopFieldCollector extends TopDocsCollector { FieldValueHitQueue queue = FieldValueHitQueue.create(sort.fields, numHits); if (after == null) { - return new SimpleFieldCollector(sort, queue, numHits, hitsThresholdChecker); + return new SimpleFieldCollector(sort, queue, numHits, hitsThresholdChecker, bottomValueChecker); } else { if (after.fields == null) { throw new IllegalArgumentException("after.fields wasn't set; you must pass fillFields=true for the previous search"); @@ -423,22 +440,24 @@ public abstract class TopFieldCollector extends TopDocsCollector { throw new IllegalArgumentException("after.fields has " + after.fields.length + " values but sort has " + sort.getSort().length); } - return new PagingFieldCollector(sort, queue, after, numHits, hitsThresholdChecker); + return new PagingFieldCollector(sort, queue, after, numHits, hitsThresholdChecker, bottomValueChecker); } } /** * Create a CollectorManager which uses a shared hit counter to maintain number of hits + * and a shared bottom value checker to propagate the minimum score across segments if + * the primary sort is by relevancy. 
*/ - public static CollectorManager createSharedManager(Sort sort, int numHits, FieldDoc after, - int totalHitsThreshold) { + public static CollectorManager createSharedManager(Sort sort, int numHits, FieldDoc after, int totalHitsThreshold) { return new CollectorManager<>() { private final HitsThresholdChecker hitsThresholdChecker = HitsThresholdChecker.createShared(totalHitsThreshold); + private final BottomValueChecker bottomValueChecker = BottomValueChecker.createMaxBottomScoreChecker(); @Override public TopFieldCollector newCollector() throws IOException { - return create(sort, numHits, after, hitsThresholdChecker); + return create(sort, numHits, after, hitsThresholdChecker, bottomValueChecker); } @Override diff --git a/lucene/core/src/test/org/apache/lucene/search/TestTopDocsCollector.java b/lucene/core/src/test/org/apache/lucene/search/TestTopDocsCollector.java index 130449bbbd0..e0818e2f120 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestTopDocsCollector.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestTopDocsCollector.java @@ -25,19 +25,16 @@ import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.StringField; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.MultiTerms; import org.apache.lucene.index.NoMergePolicy; import org.apache.lucene.index.RandomIndexWriter; import org.apache.lucene.index.Term; -import org.apache.lucene.index.Terms; -import org.apache.lucene.index.TermsEnum; import org.apache.lucene.store.Directory; -import org.apache.lucene.util.BytesRef; -import org.apache.lucene.util.LineFileDocs; import org.apache.lucene.util.LuceneTestCase; import org.apache.lucene.util.NamedThreadFactory; @@ -112,30 
+109,53 @@ public class TestTopDocsCollector extends LuceneTestCase { return tdc; } - private TopDocsCollector doSearchWithThreshold(int numResults, int thresHold) throws IOException { - Query q = new MatchAllDocsQuery(); - IndexSearcher searcher = newSearcher(reader); + private TopDocsCollector doSearchWithThreshold(int numResults, int thresHold, Query q, IndexReader indexReader) throws IOException { + IndexSearcher searcher = new IndexSearcher(indexReader); TopDocsCollector tdc = TopScoreDocCollector.create(numResults, thresHold); searcher.search(q, tdc); return tdc; } - private TopDocs doConcurrentSearchWithThreshold(int numResults, int threshold, IndexReader reader) throws IOException { - Query q = new MatchAllDocsQuery(); + private TopDocs doConcurrentSearchWithThreshold(int numResults, int threshold, Query q, IndexReader indexReader) throws IOException { ExecutorService service = new ThreadPoolExecutor(4, 4, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue(), new NamedThreadFactory("TestTopDocsCollector")); - IndexSearcher searcher = new IndexSearcher(reader, service); + try { + IndexSearcher searcher = new IndexSearcher(indexReader, service); - CollectorManager collectorManager = TopScoreDocCollector.createSharedManager(numResults, - null, threshold); + CollectorManager collectorManager = TopScoreDocCollector.createSharedManager(numResults, + null, threshold); - TopDocs tdc = (TopDocs) searcher.search(q, collectorManager); - - service.shutdown(); + return (TopDocs) searcher.search(q, collectorManager); + } finally { + service.shutdown(); + } + } + private TopFieldCollector doSearchWithThreshold(int numResults, int thresHold, Query q, Sort sort, IndexReader indexReader) throws IOException { + IndexSearcher searcher = new IndexSearcher(indexReader); + TopFieldCollector tdc = TopFieldCollector.create(sort, numResults, thresHold); + searcher.search(q, tdc); return tdc; } + + private TopDocs doConcurrentSearchWithThreshold(int numResults, int threshold, Query 
q, Sort sort, IndexReader indexReader) throws IOException { + ExecutorService service = new ThreadPoolExecutor(4, 4, 0L, TimeUnit.MILLISECONDS, + new LinkedBlockingQueue(), + new NamedThreadFactory("TestTopDocsCollector")); + try { + IndexSearcher searcher = new IndexSearcher(indexReader, service); + + CollectorManager collectorManager = TopFieldCollector.createSharedManager(sort, numResults, + null, threshold); + + TopDocs tdc = (TopDocs) searcher.search(q, collectorManager); + + return tdc; + } finally { + service.shutdown(); + } + } @Override public void setUp() throws Exception { @@ -344,8 +364,8 @@ public class TestTopDocsCollector extends LuceneTestCase { assertEquals(2, reader.leaves().size()); w.close(); - TopDocsCollector collector = doSearchWithThreshold(5, 10); - TopDocs tdc = doConcurrentSearchWithThreshold(5, 10, reader); + TopDocsCollector collector = doSearchWithThreshold( 5, 10, q, reader); + TopDocs tdc = doConcurrentSearchWithThreshold(5, 10, q, reader); TopDocs tdc2 = collector.topDocs(); CheckHits.checkEqual(q, tdc.scoreDocs, tdc2.scoreDocs); @@ -404,43 +424,55 @@ public class TestTopDocsCollector extends LuceneTestCase { public void testGlobalScore() throws Exception { Directory dir = newDirectory(); - RandomIndexWriter writer = new RandomIndexWriter(random(), dir); - try (LineFileDocs docs = new LineFileDocs(random())) { - int numDocs = atLeast(100); - for (int i = 0; i < numDocs; i++) { - writer.addDocument(docs.nextDoc()); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, newIndexWriterConfig()); + int numDocs = atLeast(1000); + for (int i = 0; i < numDocs; ++i) { + int numAs = 1 + random().nextInt(5); + int numBs = random().nextFloat() < 0.5f ? 0 : 1 + random().nextInt(5); + int numCs = random().nextFloat() < 0.1f ? 
0 : 1 + random().nextInt(5); + Document doc = new Document(); + for (int j = 0; j < numAs; ++j) { + doc.add(new StringField("f", "A", Field.Store.NO)); } - } - - IndexReader reader = writer.getReader(); - writer.close(); - - final IndexSearcher s = newSearcher(reader); - Terms terms = MultiTerms.getTerms(reader, "body"); - int termCount = 0; - TermsEnum termsEnum = terms.iterator(); - while(termsEnum.next() != null) { - termCount++; - } - assertTrue(termCount > 0); - - // Target ~10 terms to search: - double chance = 10.0 / termCount; - termsEnum = terms.iterator(); - while(termsEnum.next() != null) { - if (random().nextDouble() <= chance) { - BytesRef term = BytesRef.deepCopyOf(termsEnum.term()); - Query query = new TermQuery(new Term("body", term)); - - TopDocsCollector collector = doSearchWithThreshold(5, 10); - TopDocs tdc = doConcurrentSearchWithThreshold(5, 10, reader); - TopDocs tdc2 = collector.topDocs(); - - CheckHits.checkEqual(query, tdc.scoreDocs, tdc2.scoreDocs); + for (int j = 0; j < numBs; ++j) { + doc.add(new StringField("f", "B", Field.Store.NO)); } + for (int j = 0; j < numCs; ++j) { + doc.add(new StringField("f", "C", Field.Store.NO)); + } + w.addDocument(doc); + } + IndexReader indexReader = w.getReader(); + w.close(); + Query[] queries = new Query[]{ + new TermQuery(new Term("f", "A")), + new TermQuery(new Term("f", "B")), + new TermQuery(new Term("f", "C")), + new BooleanQuery.Builder() + .add(new TermQuery(new Term("f", "A")), BooleanClause.Occur.MUST) + .add(new TermQuery(new Term("f", "B")), BooleanClause.Occur.SHOULD) + .build() + }; + for (Query query : queries) { + TopDocsCollector collector = doSearchWithThreshold(5, 10, query, indexReader); + TopDocs tdc = doConcurrentSearchWithThreshold(5, 10, query, indexReader); + TopDocs tdc2 = collector.topDocs(); + + assertTrue(tdc.totalHits.value > 0); + assertTrue(tdc2.totalHits.value > 0); + CheckHits.checkEqual(query, tdc.scoreDocs, tdc2.scoreDocs); + + Sort sort = new Sort(new 
SortField[]{SortField.FIELD_SCORE, SortField.FIELD_DOC}); + TopDocsCollector fieldCollector = doSearchWithThreshold(5, 10, query, sort, indexReader); + tdc = doConcurrentSearchWithThreshold(5, 10, query, sort, indexReader); + tdc2 = fieldCollector.topDocs(); + + assertTrue(tdc.totalHits.value > 0); + assertTrue(tdc2.totalHits.value > 0); + CheckHits.checkEqual(query, tdc.scoreDocs, tdc2.scoreDocs); } - reader.close(); + indexReader.close(); dir.close(); }