LUCENE-10208: Ensure that the minimum competitive score does not decrease in concurrent search (#431)

Co-authored-by: Adrien Grand <jpountz@gmail.com>
This commit is contained in:
Jim Ferenczi 2021-11-09 11:04:17 +01:00 committed by GitHub
parent 263765a9b0
commit 94b66c0ed2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 105 additions and 94 deletions

View File

@ -548,6 +548,8 @@ Bug Fixes
* LUCENE-10154: NumericLeafComparator to define getPointValues. (Mayya Sharipova, Adrien Grand)
* LUCENE-10208: Ensure that the minimum competitive score does not decrease in concurrent search. (Jim Ferenczi, Adrien Grand)
Build
---------------------

View File

@ -26,7 +26,7 @@ final class MaxScoreAccumulator {
static final int DEFAULT_INTERVAL = 0x3ff;
// scores are always positive
final LongAccumulator acc = new LongAccumulator(Long::max, Long.MIN_VALUE);
final LongAccumulator acc = new LongAccumulator(MaxScoreAccumulator::maxEncode, Long.MIN_VALUE);
// non-final and visible for tests
long modInterval;
@ -35,9 +35,26 @@ final class MaxScoreAccumulator {
this.modInterval = DEFAULT_INTERVAL;
}
void accumulate(int docID, float score) {
assert docID >= 0 && score >= 0;
long encode = (((long) Float.floatToIntBits(score)) << 32) | docID;
/**
* Return the max encoded DocAndScore in a way that is consistent with {@link
* DocAndScore#compareTo}.
*/
private static long maxEncode(long v1, long v2) {
float score1 = Float.intBitsToFloat((int) (v1 >> 32));
float score2 = Float.intBitsToFloat((int) (v2 >> 32));
int cmp = Float.compare(score1, score2);
if (cmp == 0) {
// tie-break on the minimum doc base
return (int) v1 < (int) v2 ? v1 : v2;
} else if (cmp > 0) {
return v1;
}
return v2;
}
void accumulate(int docBase, float score) {
assert docBase >= 0 && score >= 0;
long encode = (((long) Float.floatToIntBits(score)) << 32) | docBase;
acc.accumulate(encode);
}
@ -47,16 +64,16 @@ final class MaxScoreAccumulator {
return null;
}
float score = Float.intBitsToFloat((int) (value >> 32));
int docID = (int) value;
return new DocAndScore(docID, score);
int docBase = (int) value;
return new DocAndScore(docBase, score);
}
static class DocAndScore implements Comparable<DocAndScore> {
final int docID;
final int docBase;
final float score;
DocAndScore(int docID, float score) {
this.docID = docID;
DocAndScore(int docBase, float score) {
this.docBase = docBase;
this.score = score;
}
@ -64,7 +81,14 @@ final class MaxScoreAccumulator {
public int compareTo(DocAndScore o) {
int cmp = Float.compare(score, o.score);
if (cmp == 0) {
return Integer.compare(docID, o.docID);
// tie-break on the minimum doc base
// For a given minimum competitive score, we want to know the first segment
// where this score occurred, hence the reverse order here.
// On segments with a lower docBase, any document whose score is greater
// than or equal to this score would be competitive, while on segments with a
// higher docBase, documents need to have a strictly greater score to be
// competitive since we tie break on doc ID.
return Integer.compare(o.docBase, docBase);
}
return cmp;
}
@ -74,17 +98,17 @@ final class MaxScoreAccumulator {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
DocAndScore result = (DocAndScore) o;
return docID == result.docID && Float.compare(result.score, score) == 0;
return docBase == result.docBase && Float.compare(result.score, score) == 0;
}
@Override
public int hashCode() {
return Objects.hash(docID, score);
return Objects.hash(docBase, score);
}
@Override
public String toString() {
return "DocAndScore{" + "docID=" + docID + ", score=" + score + '}';
return "DocAndScore{" + "docBase=" + docBase + ", score=" + score + '}';
}
}
}

View File

@ -134,9 +134,9 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
public void setScorer(Scorable scorer) throws IOException {
this.scorer = scorer;
comparator.setScorer(scorer);
minCompetitiveScore = 0f;
updateMinCompetitiveScore(scorer);
if (minScoreAcc != null) {
if (minScoreAcc == null) {
updateMinCompetitiveScore(scorer);
} else {
updateGlobalMinCompetitiveScore(scorer);
}
}
@ -191,6 +191,8 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
@Override
public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
// reset the minimum competitive score
minCompetitiveScore = 0f;
docBase = context.docBase;
return new TopFieldLeafCollector(queue, sort, context) {
@ -244,6 +246,8 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
@Override
public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
// reset the minimum competitive score
minCompetitiveScore = 0f;
docBase = context.docBase;
final int afterDoc = after.doc - docBase;
@ -363,7 +367,7 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
minCompetitiveScore = minScore;
totalHitsRelation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
if (minScoreAcc != null) {
minScoreAcc.accumulate(bottom.doc, minScore);
minScoreAcc.accumulate(docBase, minScore);
}
}
}

View File

@ -55,14 +55,15 @@ public abstract class TopScoreDocCollector extends TopDocsCollector<ScoreDoc> {
public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
// reset the minimum competitive score
docBase = context.docBase;
return new ScorerLeafCollector() {
minCompetitiveScore = 0f;
return new ScorerLeafCollector() {
@Override
public void setScorer(Scorable scorer) throws IOException {
super.setScorer(scorer);
minCompetitiveScore = 0f;
updateMinCompetitiveScore(scorer);
if (minScoreAcc != null) {
if (minScoreAcc == null) {
updateMinCompetitiveScore(scorer);
} else {
updateGlobalMinCompetitiveScore(scorer);
}
}
@ -132,8 +133,19 @@ public abstract class TopScoreDocCollector extends TopDocsCollector<ScoreDoc> {
public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
docBase = context.docBase;
final int afterDoc = after.doc - context.docBase;
minCompetitiveScore = 0f;
return new ScorerLeafCollector() {
@Override
public void setScorer(Scorable scorer) throws IOException {
super.setScorer(scorer);
if (minScoreAcc == null) {
updateMinCompetitiveScore(scorer);
} else {
updateGlobalMinCompetitiveScore(scorer);
}
}
@Override
public void collect(int doc) throws IOException {
float score = scorer.score();
@ -307,7 +319,7 @@ public abstract class TopScoreDocCollector extends TopDocsCollector<ScoreDoc> {
// the next float if the global minimum score is set on a document id that is
// smaller than the ids in the current leaf
float score =
docBase > maxMinScore.docID ? Math.nextUp(maxMinScore.score) : maxMinScore.score;
docBase >= maxMinScore.docBase ? Math.nextUp(maxMinScore.score) : maxMinScore.score;
if (score > minCompetitiveScore) {
assert hitsThresholdChecker.isThresholdReached();
scorer.setMinCompetitiveScore(score);
@ -332,7 +344,7 @@ public abstract class TopScoreDocCollector extends TopDocsCollector<ScoreDoc> {
// we don't use the next float but we register the document
// id so that other leaves can require it if they are after
// the current maximum
minScoreAcc.accumulate(pqTop.doc, pqTop.score);
minScoreAcc.accumulate(docBase, pqTop.score);
}
}
}

View File

@ -23,21 +23,29 @@ public class TestMaxScoreAccumulator extends LuceneTestCase {
public void testSimple() {
MaxScoreAccumulator acc = new MaxScoreAccumulator();
acc.accumulate(0, 0f);
assertEquals(0f, acc.get().score, 0);
assertEquals(0, acc.get().docBase, 0);
acc.accumulate(10, 0f);
assertEquals(0f, acc.get().score, 0);
assertEquals(10, acc.get().docID, 0);
assertEquals(0, acc.get().docBase, 0);
acc.accumulate(100, 1000f);
assertEquals(1000f, acc.get().score, 0);
assertEquals(100, acc.get().docID, 0);
assertEquals(100, acc.get().docBase, 0);
acc.accumulate(1000, 5f);
assertEquals(1000f, acc.get().score, 0);
assertEquals(100, acc.get().docID, 0);
assertEquals(100, acc.get().docBase, 0);
acc.accumulate(99, 1000f);
assertEquals(1000f, acc.get().score, 0);
assertEquals(100, acc.get().docID, 0);
acc.accumulate(0, 1001f);
assertEquals(99, acc.get().docBase, 0);
acc.accumulate(1000, 1001f);
assertEquals(1001f, acc.get().score, 0);
assertEquals(0, acc.get().docID, 0);
assertEquals(1000, acc.get().docBase, 0);
acc.accumulate(10, 1001f);
assertEquals(1001f, acc.get().score, 0);
assertEquals(10, acc.get().docBase, 0);
acc.accumulate(100, 1001f);
assertEquals(1001f, acc.get().score, 0);
assertEquals(10, acc.get().docBase, 0);
}
public void testRandom() {
@ -48,7 +56,7 @@ public class TestMaxScoreAccumulator extends LuceneTestCase {
for (int i = 0; i < numDocs; i++) {
MaxScoreAccumulator.DocAndScore res =
new MaxScoreAccumulator.DocAndScore(random().nextInt(maxDocs), random().nextFloat());
acc.accumulate(res.docID, res.score);
acc.accumulate(res.docBase, res.score);
if (res.compareTo(max) > 0) {
max = res;
}

View File

@ -18,10 +18,6 @@ package org.apache.lucene.search;
import java.io.IOException;
import java.util.Arrays;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.Field.Store;
@ -41,7 +37,6 @@ import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.LineFileDocs;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.NamedThreadFactory;
public class TestTopDocsCollector extends LuceneTestCase {
@ -141,7 +136,7 @@ public class TestTopDocsCollector extends LuceneTestCase {
private TopDocsCollector<ScoreDoc> doSearchWithThreshold(
int numResults, int thresHold, Query q, IndexReader indexReader) throws IOException {
IndexSearcher searcher = new IndexSearcher(indexReader);
IndexSearcher searcher = newSearcher(indexReader, true, true, false);
TopDocsCollector<ScoreDoc> tdc = TopScoreDocCollector.create(numResults, thresHold);
searcher.search(q, tdc);
return tdc;
@ -149,24 +144,10 @@ public class TestTopDocsCollector extends LuceneTestCase {
private TopDocs doConcurrentSearchWithThreshold(
int numResults, int threshold, Query q, IndexReader indexReader) throws IOException {
ExecutorService service =
new ThreadPoolExecutor(
4,
4,
0L,
TimeUnit.MILLISECONDS,
new LinkedBlockingQueue<Runnable>(),
new NamedThreadFactory("TestTopDocsCollector"));
try {
IndexSearcher searcher = new IndexSearcher(indexReader, service);
CollectorManager<TopScoreDocCollector, TopDocs> collectorManager =
TopScoreDocCollector.createSharedManager(numResults, null, threshold);
return searcher.search(q, collectorManager);
} finally {
service.shutdown();
}
IndexSearcher searcher = newSearcher(indexReader, true, true, true);
CollectorManager<TopScoreDocCollector, TopDocs> collectorManager =
TopScoreDocCollector.createSharedManager(numResults, null, threshold);
return searcher.search(q, collectorManager);
}
@Override
@ -303,8 +284,9 @@ public class TestTopDocsCollector extends LuceneTestCase {
Float minCompetitiveScore = null;
@Override
public void setMinCompetitiveScore(float minCompetitiveScore) {
this.minCompetitiveScore = minCompetitiveScore;
public void setMinCompetitiveScore(float score) {
assert minCompetitiveScore == null || score >= minCompetitiveScore;
this.minCompetitiveScore = score;
}
@Override
@ -356,9 +338,9 @@ public class TestTopDocsCollector extends LuceneTestCase {
scorer.doc = 3;
scorer.score = 0.5f;
// Make sure we do not call setMinCompetitiveScore for non-competitive hits
scorer.minCompetitiveScore = Float.NaN;
scorer.minCompetitiveScore = null;
leafCollector.collect(3);
assertTrue(Float.isNaN(scorer.minCompetitiveScore));
assertNull(scorer.minCompetitiveScore);
scorer.doc = 4;
scorer.score = 4;
@ -613,6 +595,10 @@ public class TestTopDocsCollector extends LuceneTestCase {
assertEquals(11, topDocs.totalHits.value);
assertEquals(new TotalHits(11, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), topDocs.totalHits);
leafCollector.setScorer(scorer);
leafCollector2.setScorer(scorer2);
leafCollector3.setScorer(scorer3);
reader.close();
dir.close();
}

View File

@ -21,10 +21,6 @@ import static org.apache.lucene.search.SortField.FIELD_SCORE;
import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.Field.Store;
@ -42,7 +38,6 @@ import org.apache.lucene.search.BooleanClause.Occur;
import org.apache.lucene.search.FieldValueHitQueue.Entry;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.NamedThreadFactory;
import org.apache.lucene.util.TestUtil;
public class TestTopFieldCollector extends LuceneTestCase {
@ -75,7 +70,7 @@ public class TestTopFieldCollector extends LuceneTestCase {
private TopFieldCollector doSearchWithThreshold(
int numResults, int thresHold, Query q, Sort sort, IndexReader indexReader)
throws IOException {
IndexSearcher searcher = new IndexSearcher(indexReader);
IndexSearcher searcher = newSearcher(indexReader);
TopFieldCollector tdc = TopFieldCollector.create(sort, numResults, thresHold);
searcher.search(q, tdc);
return tdc;
@ -84,26 +79,14 @@ public class TestTopFieldCollector extends LuceneTestCase {
private TopDocs doConcurrentSearchWithThreshold(
int numResults, int threshold, Query q, Sort sort, IndexReader indexReader)
throws IOException {
ExecutorService service =
new ThreadPoolExecutor(
4,
4,
0L,
TimeUnit.MILLISECONDS,
new LinkedBlockingQueue<Runnable>(),
new NamedThreadFactory("TestTopDocsCollector"));
try {
IndexSearcher searcher = new IndexSearcher(indexReader, service);
IndexSearcher searcher = newSearcher(indexReader, true, true, true);
CollectorManager<TopFieldCollector, TopFieldDocs> collectorManager =
TopFieldCollector.createSharedManager(sort, numResults, null, threshold);
CollectorManager<TopFieldCollector, TopFieldDocs> collectorManager =
TopFieldCollector.createSharedManager(sort, numResults, null, threshold);
TopDocs tdc = searcher.search(q, collectorManager);
TopDocs tdc = searcher.search(q, collectorManager);
return tdc;
} finally {
service.shutdown();
}
return tdc;
}
public void testSortWithoutFillFields() throws Exception {
@ -146,17 +129,7 @@ public class TestTopFieldCollector extends LuceneTestCase {
}
public void testSharedHitcountCollector() throws Exception {
ExecutorService service =
new ThreadPoolExecutor(
4,
4,
0L,
TimeUnit.MILLISECONDS,
new LinkedBlockingQueue<Runnable>(),
new NamedThreadFactory("TestTopFieldCollector"));
IndexSearcher concurrentSearcher = new IndexSearcher(ir, service);
IndexSearcher concurrentSearcher = newSearcher(ir, true, true, true);
// Two Sort criteria to instantiate the multi/single comparators.
Sort[] sort = new Sort[] {new Sort(SortField.FIELD_DOC), new Sort()};
@ -178,8 +151,6 @@ public class TestTopFieldCollector extends LuceneTestCase {
CheckHits.checkEqual(q, td.scoreDocs, td2.scoreDocs);
}
service.shutdown();
}
public void testSortWithoutTotalHitTracking() throws Exception {
@ -678,6 +649,10 @@ public class TestTopFieldCollector extends LuceneTestCase {
assertEquals(11, topDocs.totalHits.value);
assertEquals(new TotalHits(11, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), topDocs.totalHits);
leafCollector.setScorer(scorer);
leafCollector2.setScorer(scorer2);
leafCollector3.setScorer(scorer3);
reader.close();
dir.close();
}