Remove HitsThresholdChecker. (#13943)

`TopScoreDocCollectorManager` has a dependency on `HitsThresholdChecker`, which
is essentially a shared counter that is incremented until it reaches the total
hits threshold, when the scorer can start dynamically pruning hits.

A consequence of this removal is that dynamic pruning may start later, as soon
as:
 - either the current slice collected `totalHitsThreshold` hits,
 - or another slice collected `totalHitsThreshold` hits and the current slice
   collected enough hits (up to 1,024) to check the shared
   `MaxScoreAccumulator`.

So in short, it exchanges a bit more work globally in favor of a bit less
contention. A longer-term goal of mine is to stop specializing our
`CollectorManager`s based on whether they are going to be used concurrently or
not.
This commit is contained in:
Adrien Grand 2024-10-28 15:50:53 +01:00 committed by GitHub
parent 81ab3b9722
commit 937432acd8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 52 additions and 216 deletions

View File

@ -72,6 +72,8 @@ Optimizations
* GITHUB#13899: Check ahead if we can get the count. (Lu Xugang)
* GITHUB#13943: Removed shared `HitsThresholdChecker`, which reduces overhead
but may delay a bit when dynamic pruning kicks in. (Adrien Grand)
Bug Fixes
---------------------

View File

@ -1,147 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.util.concurrent.atomic.LongAdder;
/** Used for defining custom algorithms to allow searches to early terminate */
abstract class HitsThresholdChecker {
/** Implementation of HitsThresholdChecker which allows global hit counting */
private static class GlobalHitsThresholdChecker extends HitsThresholdChecker {
private final LongAdder globalHitCount = new LongAdder();
// Cache whether the threshold has been reached already. It is not volatile or synchronized on
// purpose to contain the overhead of reading the value similarly to what String#hashCode()
// does. This does not affect correctness.
private boolean thresholdReached = false;
GlobalHitsThresholdChecker(int totalHitsThreshold) {
super(totalHitsThreshold);
assert totalHitsThreshold != Integer.MAX_VALUE;
}
@Override
void incrementHitCount() {
if (thresholdReached == false) {
globalHitCount.increment();
}
}
@Override
boolean isThresholdReached() {
if (thresholdReached) {
return true;
}
if (globalHitCount.longValue() > getHitsThreshold()) {
thresholdReached = true;
return true;
}
return false;
}
@Override
ScoreMode scoreMode() {
return ScoreMode.TOP_SCORES;
}
}
/** Default implementation of HitsThresholdChecker to be used for single threaded execution */
private static class LocalHitsThresholdChecker extends HitsThresholdChecker {
private int hitCount;
LocalHitsThresholdChecker(int totalHitsThreshold) {
super(totalHitsThreshold);
assert totalHitsThreshold != Integer.MAX_VALUE;
}
@Override
void incrementHitCount() {
++hitCount;
}
@Override
boolean isThresholdReached() {
return hitCount > getHitsThreshold();
}
@Override
ScoreMode scoreMode() {
return ScoreMode.TOP_SCORES;
}
}
/**
* No-op implementation of {@link HitsThresholdChecker} that does no counting, as the threshold
* can never be reached. This is useful for cases where early termination is never desired, so
* that the overhead of counting hits can be avoided.
*/
private static final HitsThresholdChecker EXACT_HITS_COUNT_THRESHOLD_CHECKER =
new HitsThresholdChecker(Integer.MAX_VALUE) {
@Override
void incrementHitCount() {
// noop
}
@Override
boolean isThresholdReached() {
return false;
}
@Override
ScoreMode scoreMode() {
return ScoreMode.COMPLETE;
}
};
/*
* Returns a threshold checker that is useful for single threaded searches
*/
static HitsThresholdChecker create(final int totalHitsThreshold) {
return totalHitsThreshold == Integer.MAX_VALUE
? HitsThresholdChecker.EXACT_HITS_COUNT_THRESHOLD_CHECKER
: new LocalHitsThresholdChecker(totalHitsThreshold);
}
/*
* Returns a threshold checker that is based on a shared counter
*/
static HitsThresholdChecker createShared(final int totalHitsThreshold) {
return totalHitsThreshold == Integer.MAX_VALUE
? HitsThresholdChecker.EXACT_HITS_COUNT_THRESHOLD_CHECKER
: new GlobalHitsThresholdChecker(totalHitsThreshold);
}
private final int totalHitsThreshold;
HitsThresholdChecker(int totalHitsThreshold) {
if (totalHitsThreshold < 0) {
throw new IllegalArgumentException(
"totalHitsThreshold must be >= 0, got " + totalHitsThreshold);
}
this.totalHitsThreshold = totalHitsThreshold;
}
final int getHitsThreshold() {
return totalHitsThreshold;
}
abstract boolean isThresholdReached();
abstract ScoreMode scoreMode();
abstract void incrementHitCount();
}

View File

@ -71,15 +71,14 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
}
void countHit(int doc) throws IOException {
++totalHits;
hitsThresholdChecker.incrementHitCount();
int hitCountSoFar = ++totalHits;
if (minScoreAcc != null && (totalHits & minScoreAcc.modInterval) == 0) {
if (minScoreAcc != null && (hitCountSoFar & minScoreAcc.modInterval) == 0) {
updateGlobalMinCompetitiveScore(scorer);
}
if (scoreMode.isExhaustive() == false
&& totalHitsRelation == TotalHits.Relation.EQUAL_TO
&& hitsThresholdChecker.isThresholdReached()) {
&& totalHits > totalHitsThreshold) {
// for the first time hitsThreshold is reached, notify comparator about this
comparator.setHitsThresholdReached();
totalHitsRelation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
@ -92,7 +91,7 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
// this document is larger than anything else in the queue, and
// therefore not competitive.
if (searchSortPartOfIndexSort) {
if (hitsThresholdChecker.isThresholdReached()) {
if (totalHits > totalHitsThreshold) {
totalHitsRelation = Relation.GREATER_THAN_OR_EQUAL_TO;
throw new CollectionTerminatedException();
} else {
@ -180,9 +179,9 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
Sort sort,
FieldValueHitQueue<Entry> queue,
int numHits,
HitsThresholdChecker hitsThresholdChecker,
int totalHitsThreshold,
MaxScoreAccumulator minScoreAcc) {
super(queue, numHits, hitsThresholdChecker, sort.needsScores(), minScoreAcc);
super(queue, numHits, totalHitsThreshold, sort.needsScores(), minScoreAcc);
this.sort = sort;
this.queue = queue;
}
@ -235,9 +234,9 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
FieldValueHitQueue<Entry> queue,
FieldDoc after,
int numHits,
HitsThresholdChecker hitsThresholdChecker,
int totalHitsThreshold,
MaxScoreAccumulator minScoreAcc) {
super(queue, numHits, hitsThresholdChecker, sort.needsScores(), minScoreAcc);
super(queue, numHits, totalHitsThreshold, sort.needsScores(), minScoreAcc);
this.sort = sort;
this.queue = queue;
this.after = after;
@ -301,7 +300,7 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
private static final ScoreDoc[] EMPTY_SCOREDOCS = new ScoreDoc[0];
final int numHits;
final HitsThresholdChecker hitsThresholdChecker;
final int totalHitsThreshold;
final FieldComparator<?> firstComparator;
final boolean canSetMinScore;
@ -327,25 +326,25 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
private TopFieldCollector(
FieldValueHitQueue<Entry> pq,
int numHits,
HitsThresholdChecker hitsThresholdChecker,
int totalHitsThreshold,
boolean needsScores,
MaxScoreAccumulator minScoreAcc) {
super(pq);
this.needsScores = needsScores;
this.numHits = numHits;
this.hitsThresholdChecker = hitsThresholdChecker;
this.totalHitsThreshold = Math.max(totalHitsThreshold, numHits);
this.numComparators = pq.getComparators().length;
this.firstComparator = pq.getComparators()[0];
int reverseMul = pq.reverseMul[0];
if (firstComparator.getClass().equals(FieldComparator.RelevanceComparator.class)
&& reverseMul == 1 // if the natural sort is preserved (sort by descending relevance)
&& hitsThresholdChecker.getHitsThreshold() != Integer.MAX_VALUE) {
&& totalHitsThreshold != Integer.MAX_VALUE) {
scoreMode = ScoreMode.TOP_SCORES;
canSetMinScore = true;
} else {
canSetMinScore = false;
if (hitsThresholdChecker.getHitsThreshold() != Integer.MAX_VALUE) {
if (totalHitsThreshold != Integer.MAX_VALUE) {
scoreMode = needsScores ? ScoreMode.TOP_DOCS_WITH_SCORES : ScoreMode.TOP_DOCS;
} else {
scoreMode = needsScores ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES;
@ -361,10 +360,10 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
protected void updateGlobalMinCompetitiveScore(Scorable scorer) throws IOException {
assert minScoreAcc != null;
if (canSetMinScore && hitsThresholdChecker.isThresholdReached()) {
// we can start checking the global maximum score even
// if the local queue is not full because the threshold
// is reached.
if (canSetMinScore) {
// we can start checking the global maximum score even if the local queue is not full or if
// the threshold is not reached on the local competitor: the fact that there is a shared min
// competitive score implies that one of the collectors hit its totalHitsThreshold already
long maxMinScore = minScoreAcc.getRaw();
float score;
if (maxMinScore != Long.MIN_VALUE
@ -377,7 +376,7 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
}
protected void updateMinCompetitiveScore(Scorable scorer) throws IOException {
if (canSetMinScore && queueFull && hitsThresholdChecker.isThresholdReached()) {
if (canSetMinScore && queueFull && totalHits > totalHitsThreshold) {
assert bottom != null;
float minScore = (float) firstComparator.value(bottom.slot);
if (minScore > minCompetitiveScore) {

View File

@ -32,7 +32,7 @@ public class TopFieldCollectorManager implements CollectorManager<TopFieldCollec
private final Sort sort;
private final int numHits;
private final FieldDoc after;
private final HitsThresholdChecker hitsThresholdChecker;
private final int totalHitsThreshold;
private final MaxScoreAccumulator minScoreAcc;
private final List<TopFieldCollector> collectors;
private final boolean supportsConcurrency;
@ -89,10 +89,7 @@ public class TopFieldCollectorManager implements CollectorManager<TopFieldCollec
this.numHits = numHits;
this.after = after;
this.supportsConcurrency = supportsConcurrency;
this.hitsThresholdChecker =
supportsConcurrency
? HitsThresholdChecker.createShared(Math.max(totalHitsThreshold, numHits))
: HitsThresholdChecker.create(Math.max(totalHitsThreshold, numHits));
this.totalHitsThreshold = totalHitsThreshold;
this.minScoreAcc =
supportsConcurrency && totalHitsThreshold != Integer.MAX_VALUE
? new MaxScoreAccumulator()
@ -162,7 +159,7 @@ public class TopFieldCollectorManager implements CollectorManager<TopFieldCollec
}
collector =
new TopFieldCollector.SimpleFieldCollector(
sort, queue, numHits, hitsThresholdChecker, minScoreAcc);
sort, queue, numHits, totalHitsThreshold, minScoreAcc);
} else {
if (after.fields == null) {
throw new IllegalArgumentException(
@ -178,7 +175,7 @@ public class TopFieldCollectorManager implements CollectorManager<TopFieldCollec
}
collector =
new TopFieldCollector.PagingFieldCollector(
sort, queue, after, numHits, hitsThresholdChecker, minScoreAcc);
sort, queue, after, numHits, totalHitsThreshold, minScoreAcc);
}
collectors.add(collector);

View File

@ -45,8 +45,8 @@ public abstract class TopScoreDocCollector extends TopDocsCollector<ScoreDoc> {
static class SimpleTopScoreDocCollector extends TopScoreDocCollector {
SimpleTopScoreDocCollector(
int numHits, HitsThresholdChecker hitsThresholdChecker, MaxScoreAccumulator minScoreAcc) {
super(numHits, hitsThresholdChecker, minScoreAcc);
int numHits, int totalHitsThreshold, MaxScoreAccumulator minScoreAcc) {
super(numHits, totalHitsThreshold, minScoreAcc);
}
@Override
@ -71,7 +71,6 @@ public abstract class TopScoreDocCollector extends TopDocsCollector<ScoreDoc> {
float score = scorer.score();
int hitCountSoFar = ++totalHits;
hitsThresholdChecker.incrementHitCount();
if (minScoreAcc != null && (hitCountSoFar & minScoreAcc.modInterval) == 0) {
updateGlobalMinCompetitiveScore(scorer);
@ -80,7 +79,7 @@ public abstract class TopScoreDocCollector extends TopDocsCollector<ScoreDoc> {
if (score <= pqTop.score) {
// Note: for queries that match lots of hits, this is the common case: most hits are not
// competitive.
if (totalHitsRelation == TotalHits.Relation.EQUAL_TO) {
if (hitCountSoFar == totalHitsThreshold + 1) {
// we just reached totalHitsThreshold, we can start setting the min
// competitive score now
updateMinCompetitiveScore(scorer);
@ -108,11 +107,8 @@ public abstract class TopScoreDocCollector extends TopDocsCollector<ScoreDoc> {
private final ScoreDoc after;
PagingTopScoreDocCollector(
int numHits,
ScoreDoc after,
HitsThresholdChecker hitsThresholdChecker,
MaxScoreAccumulator minScoreAcc) {
super(numHits, hitsThresholdChecker, minScoreAcc);
int numHits, ScoreDoc after, int totalHitsThreshold, MaxScoreAccumulator minScoreAcc) {
super(numHits, totalHitsThreshold, minScoreAcc);
this.after = after;
}
@ -158,7 +154,6 @@ public abstract class TopScoreDocCollector extends TopDocsCollector<ScoreDoc> {
float score = scorer.score();
int hitCountSoFar = ++totalHits;
hitsThresholdChecker.incrementHitCount();
if (minScoreAcc != null && (hitCountSoFar & minScoreAcc.modInterval) == 0) {
updateGlobalMinCompetitiveScore(scorer);
@ -178,8 +173,8 @@ public abstract class TopScoreDocCollector extends TopDocsCollector<ScoreDoc> {
if (score <= pqTop.score) {
// Note: for queries that match lots of hits, this is the common case: most hits are not
// competitive.
if (totalHitsRelation == TotalHits.Relation.EQUAL_TO) {
// we just reached totalHitsThreshold, we can start setting the min
if (hitCountSoFar == totalHitsThreshold + 1) {
// we just exceeded totalHitsThreshold, we can start setting the min
// competitive score now
updateMinCompetitiveScore(scorer);
}
@ -204,20 +199,18 @@ public abstract class TopScoreDocCollector extends TopDocsCollector<ScoreDoc> {
int docBase;
ScoreDoc pqTop;
final HitsThresholdChecker hitsThresholdChecker;
final int totalHitsThreshold;
final MaxScoreAccumulator minScoreAcc;
float minCompetitiveScore;
// prevents instantiation
TopScoreDocCollector(
int numHits, HitsThresholdChecker hitsThresholdChecker, MaxScoreAccumulator minScoreAcc) {
TopScoreDocCollector(int numHits, int totalHitsThreshold, MaxScoreAccumulator minScoreAcc) {
super(new HitQueue(numHits, true));
assert hitsThresholdChecker != null;
// HitQueue implements getSentinelObject to return a ScoreDoc, so we know
// that at this point top() is already initialized.
pqTop = pq.top();
this.hitsThresholdChecker = hitsThresholdChecker;
this.totalHitsThreshold = totalHitsThreshold;
this.minScoreAcc = minScoreAcc;
}
@ -232,7 +225,7 @@ public abstract class TopScoreDocCollector extends TopDocsCollector<ScoreDoc> {
@Override
public ScoreMode scoreMode() {
return hitsThresholdChecker.scoreMode();
return totalHitsThreshold == Integer.MAX_VALUE ? ScoreMode.COMPLETE : ScoreMode.TOP_SCORES;
}
protected void updateGlobalMinCompetitiveScore(Scorable scorer) throws IOException {
@ -245,7 +238,6 @@ public abstract class TopScoreDocCollector extends TopDocsCollector<ScoreDoc> {
float score = MaxScoreAccumulator.toScore(maxMinScore);
score = docBase >= MaxScoreAccumulator.docId(maxMinScore) ? Math.nextUp(score) : score;
if (score > minCompetitiveScore) {
assert hitsThresholdChecker.isThresholdReached();
scorer.setMinCompetitiveScore(score);
minCompetitiveScore = score;
totalHitsRelation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
@ -254,7 +246,7 @@ public abstract class TopScoreDocCollector extends TopDocsCollector<ScoreDoc> {
}
protected void updateMinCompetitiveScore(Scorable scorer) throws IOException {
if (hitsThresholdChecker.isThresholdReached()) {
if (totalHits > totalHitsThreshold) {
// since we tie-break on doc id and collect in doc id order, we can require
// the next float
// pqTop is never null since TopScoreDocCollector fills the priority queue with sentinel

View File

@ -29,7 +29,7 @@ public class TopScoreDocCollectorManager
implements CollectorManager<TopScoreDocCollector, TopDocs> {
private final int numHits;
private final ScoreDoc after;
private final HitsThresholdChecker hitsThresholdChecker;
private final int totalHitsThreshold;
private final MaxScoreAccumulator minScoreAcc;
private final boolean supportsConcurrency;
private boolean collectorCreated;
@ -71,10 +71,7 @@ public class TopScoreDocCollectorManager
this.numHits = numHits;
this.after = after;
this.supportsConcurrency = supportsConcurrency;
this.hitsThresholdChecker =
supportsConcurrency
? HitsThresholdChecker.createShared(Math.max(totalHitsThreshold, numHits))
: HitsThresholdChecker.create(Math.max(totalHitsThreshold, numHits));
this.totalHitsThreshold = Math.max(totalHitsThreshold, numHits);
this.minScoreAcc =
supportsConcurrency && totalHitsThreshold != Integer.MAX_VALUE
? new MaxScoreAccumulator()
@ -141,10 +138,10 @@ public class TopScoreDocCollectorManager
if (after == null) {
return new TopScoreDocCollector.SimpleTopScoreDocCollector(
numHits, hitsThresholdChecker, minScoreAcc);
numHits, totalHitsThreshold, minScoreAcc);
} else {
return new TopScoreDocCollector.PagingTopScoreDocCollector(
numHits, after, hitsThresholdChecker, minScoreAcc);
numHits, after, totalHitsThreshold, minScoreAcc);
}
}

View File

@ -529,26 +529,24 @@ public class TestTopDocsCollector extends LuceneTestCase {
scorer.score = 2;
leafCollector.collect(1);
assertEquals(2f, MaxScoreAccumulator.toScore(minValueChecker.getRaw()), 0f);
assertEquals(Math.nextUp(2f), scorer.minCompetitiveScore, 0f);
assertNull(scorer2.minCompetitiveScore);
assertEquals(Long.MIN_VALUE, minValueChecker.getRaw());
assertNull(scorer.minCompetitiveScore);
scorer2.score = 9;
leafCollector2.collect(1);
assertEquals(6f, MaxScoreAccumulator.toScore(minValueChecker.getRaw()), 0f);
assertEquals(Math.nextUp(2f), scorer.minCompetitiveScore, 0f);
assertEquals(Math.nextUp(6f), scorer2.minCompetitiveScore, 0f);
assertEquals(Long.MIN_VALUE, minValueChecker.getRaw());
assertNull(scorer2.minCompetitiveScore);
scorer2.score = 7;
leafCollector2.collect(2);
assertEquals(MaxScoreAccumulator.toScore(minValueChecker.getRaw()), 7f, 0f);
assertEquals(Math.nextUp(2f), scorer.minCompetitiveScore, 0f);
assertNull(scorer.minCompetitiveScore);
assertEquals(Math.nextUp(7f), scorer2.minCompetitiveScore, 0f);
scorer2.score = 1;
leafCollector2.collect(3);
assertEquals(MaxScoreAccumulator.toScore(minValueChecker.getRaw()), 7f, 0f);
assertEquals(Math.nextUp(2f), scorer.minCompetitiveScore, 0f);
assertNull(scorer.minCompetitiveScore);
assertEquals(Math.nextUp(7f), scorer2.minCompetitiveScore, 0f);
scorer.score = 10;

View File

@ -587,26 +587,24 @@ public class TestTopFieldCollector extends LuceneTestCase {
scorer.score = 2;
leafCollector.collect(1);
assertEquals(2f, MaxScoreAccumulator.toScore(minValueChecker.getRaw()), 0f);
assertEquals(2f, scorer.minCompetitiveScore, 0f);
assertNull(scorer2.minCompetitiveScore);
assertEquals(Long.MIN_VALUE, minValueChecker.getRaw());
assertNull(scorer.minCompetitiveScore);
scorer2.score = 9;
leafCollector2.collect(1);
assertEquals(6f, MaxScoreAccumulator.toScore(minValueChecker.getRaw()), 0f);
assertEquals(2f, scorer.minCompetitiveScore, 0f);
assertEquals(6f, scorer2.minCompetitiveScore, 0f);
assertEquals(Long.MIN_VALUE, minValueChecker.getRaw());
assertNull(scorer2.minCompetitiveScore);
scorer2.score = 7;
leafCollector2.collect(2);
assertEquals(7f, MaxScoreAccumulator.toScore(minValueChecker.getRaw()), 0f);
assertEquals(2f, scorer.minCompetitiveScore, 0f);
assertNull(scorer.minCompetitiveScore);
assertEquals(7f, scorer2.minCompetitiveScore, 0f);
scorer2.score = 1;
leafCollector2.collect(3);
assertEquals(7f, MaxScoreAccumulator.toScore(minValueChecker.getRaw()), 0f);
assertEquals(2f, scorer.minCompetitiveScore, 0f);
assertNull(scorer.minCompetitiveScore);
assertEquals(7f, scorer2.minCompetitiveScore, 0f);
scorer.score = 10;