Fix handling of terminate_after when size is 0 (#58212)

`terminate_after` is ignored on search requests that don't return top hits (`size` set to 0)
and do not track the number of hits accurately (`track_total_hits`).
We use early termination when the number of hits to track is reached during collection
but this breaks the hard termination of `terminate_after` if it happens before we reach
the `terminate_after` value.
This change ensures that we continue to check `terminate_after` even if the tracking of total
hits has reached the provided value.

Closes #57624
This commit is contained in:
Jim Ferenczi 2020-06-24 13:15:44 +02:00 committed by jimczi
parent 796cb9e9ca
commit ec8d5ec79c
2 changed files with 59 additions and 5 deletions

View File

@ -24,6 +24,7 @@ import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MultiCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.SimpleCollector;
import org.apache.lucene.search.Weight;
import org.elasticsearch.common.lucene.MinimumScoreCollector;
import org.elasticsearch.common.lucene.search.FilteredCollector;
@ -41,6 +42,17 @@ import static org.elasticsearch.search.profile.query.CollectorResult.REASON_SEAR
import static org.elasticsearch.search.profile.query.CollectorResult.REASON_SEARCH_TERMINATE_AFTER_COUNT;
abstract class QueryCollectorContext {
/**
 * A shared no-op {@link Collector}: {@code collect} discards every document it is given
 * and {@code scoreMode} reports {@link ScoreMode#COMPLETE_NO_SCORES}.
 * It is used as the wrapped collector of the {@link EarlyTerminatingCollector} built in
 * {@code createEarlyTerminationCollectorContext}, which is then combined with the
 * caller's collector through {@link MultiCollector}.
 */
private static final Collector EMPTY_COLLECTOR = new SimpleCollector() {
@Override
public void collect(int doc) {
// intentionally empty: this collector keeps no state about collected documents
}
@Override
public ScoreMode scoreMode() {
return ScoreMode.COMPLETE_NO_SCORES;
}
};
private String profilerName;
QueryCollectorContext(String profilerName) {
@ -124,7 +136,7 @@ abstract class QueryCollectorContext {
static QueryCollectorContext createMultiCollectorContext(Collection<Collector> subs) {
return new QueryCollectorContext(REASON_SEARCH_MULTI) {
@Override
Collector create(Collector in) throws IOException {
Collector create(Collector in) {
List<Collector> subCollectors = new ArrayList<> ();
subCollectors.add(in);
subCollectors.addAll(subs);
@ -132,7 +144,7 @@ abstract class QueryCollectorContext {
}
@Override
protected InternalProfileCollector createWithProfiler(InternalProfileCollector in) throws IOException {
protected InternalProfileCollector createWithProfiler(InternalProfileCollector in) {
final List<InternalProfileCollector> subCollectors = new ArrayList<> ();
subCollectors.add(in);
if (subs.stream().anyMatch((col) -> col instanceof InternalProfileCollector == false)) {
@ -152,12 +164,20 @@ abstract class QueryCollectorContext {
*/
static QueryCollectorContext createEarlyTerminationCollectorContext(int numHits) {
return new QueryCollectorContext(REASON_SEARCH_TERMINATE_AFTER_COUNT) {
private EarlyTerminatingCollector collector;
private Collector collector;
/**
* Creates a {@link MultiCollector} to ensure that the {@link EarlyTerminatingCollector}
* can terminate the collection independently of the provided <code>in</code> {@link Collector}.
*/
@Override
Collector create(Collector in) throws IOException {
Collector create(Collector in) {
assert collector == null;
this.collector = new EarlyTerminatingCollector(in, numHits, true);
List<Collector> subCollectors = new ArrayList<> ();
subCollectors.add(new EarlyTerminatingCollector(EMPTY_COLLECTOR, numHits, true));
subCollectors.add(in);
this.collector = MultiCollector.wrap(subCollectors);
return collector;
}
};

View File

@ -452,6 +452,40 @@ public class QueryPhaseTests extends IndexShardTestCase {
assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(0));
assertThat(collector.getTotalHits(), equalTo(1));
}
// tests with trackTotalHits and terminateAfter
context.terminateAfter(10);
context.setSize(0);
for (int trackTotalHits : new int[] { -1, 3, 76, 100}) {
context.trackTotalHitsUpTo(trackTotalHits);
TotalHitCountCollector collector = new TotalHitCountCollector();
context.queryCollectors().put(TotalHitCountCollector.class, collector);
QueryPhase.executeInternal(context);
assertTrue(context.queryResult().terminatedEarly());
if (trackTotalHits == -1) {
assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(0L));
} else {
assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) Math.min(trackTotalHits, 10)));
}
assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(0));
assertThat(collector.getTotalHits(), equalTo(10));
}
context.terminateAfter(7);
context.setSize(10);
for (int trackTotalHits : new int[] { -1, 3, 75, 100}) {
context.trackTotalHitsUpTo(trackTotalHits);
EarlyTerminatingCollector collector = new EarlyTerminatingCollector(new TotalHitCountCollector(), 1, false);
context.queryCollectors().put(EarlyTerminatingCollector.class, collector);
QueryPhase.executeInternal(context);
assertTrue(context.queryResult().terminatedEarly());
if (trackTotalHits == -1) {
assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(0L));
} else {
assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(7L));
}
assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(7));
}
reader.close();
dir.close();
}