diff --git a/server/src/main/java/org/elasticsearch/search/query/TopDocsCollectorContext.java b/server/src/main/java/org/elasticsearch/search/query/TopDocsCollectorContext.java index 3aaa640f62f..fcf70a4f98c 100644 --- a/server/src/main/java/org/elasticsearch/search/query/TopDocsCollectorContext.java +++ b/server/src/main/java/org/elasticsearch/search/query/TopDocsCollectorContext.java @@ -174,6 +174,16 @@ abstract class TopDocsCollectorContext extends QueryCollectorContext { } abstract static class SimpleTopDocsCollectorContext extends TopDocsCollectorContext { + + private static TopDocsCollector createCollector(@Nullable SortAndFormats sortAndFormats, int numHits, + @Nullable ScoreDoc searchAfter, int hitCountThreshold) { + if (sortAndFormats == null) { + return TopScoreDocCollector.create(numHits, searchAfter, hitCountThreshold); + } else { + return TopFieldCollector.create(sortAndFormats.sort, numHits, (FieldDoc) searchAfter, hitCountThreshold); + } + } + private final @Nullable SortAndFormats sortAndFormats; private final Collector collector; private final Supplier totalHitsSupplier; @@ -201,12 +211,27 @@ abstract class TopDocsCollectorContext extends QueryCollectorContext { boolean hasFilterCollector) throws IOException { super(REASON_SEARCH_TOP_HITS, numHits); this.sortAndFormats = sortAndFormats; + + // implicit total hit counts are valid only when there is no filter collector in the chain + final int hitCount = hasFilterCollector ? -1 : shortcutTotalHitCount(reader, query); + final TopDocsCollector topDocsCollector; + if (hitCount == -1 && trackTotalHits) { + topDocsCollector = createCollector(sortAndFormats, numHits, searchAfter, Integer.MAX_VALUE); + topDocsSupplier = new CachedSupplier<>(topDocsCollector::topDocs); + totalHitsSupplier = () -> topDocsSupplier.get().totalHits; + } else { + topDocsCollector = createCollector(sortAndFormats, numHits, searchAfter, 1); // don't compute hit counts via the collector + topDocsSupplier = new CachedSupplier<>(topDocsCollector::topDocs); + if (hitCount == -1) { + assert trackTotalHits == false; + totalHitsSupplier = () -> new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO); + } else { + totalHitsSupplier = () -> new TotalHits(hitCount, TotalHits.Relation.EQUAL_TO); + } + } + MaxScoreCollector maxScoreCollector = null; if (sortAndFormats == null) { - final TopDocsCollector topDocsCollector = TopScoreDocCollector.create(numHits, searchAfter, Integer.MAX_VALUE); - this.collector = topDocsCollector; - this.topDocsSupplier = new CachedSupplier<>(topDocsCollector::topDocs); - this.totalHitsSupplier = () -> topDocsSupplier.get().totalHits; - this.maxScoreSupplier = () -> { + maxScoreSupplier = () -> { TopDocs topDocs = topDocsSupplier.get(); if (topDocs.scoreDocs.length == 0) { return Float.NaN; @@ -214,42 +239,13 @@ abstract class TopDocsCollectorContext extends QueryCollectorContext { return topDocs.scoreDocs[0].score; } }; + } else if (trackMaxScore) { + maxScoreCollector = new MaxScoreCollector(); + maxScoreSupplier = maxScoreCollector::getMaxScore; } else { - /** - * We explicitly don't track total hits in the topdocs collector, it can early terminate - * if the sort matches the index sort. - */ - final TopDocsCollector topDocsCollector = TopFieldCollector.create(sortAndFormats.sort, numHits, - (FieldDoc) searchAfter, 1); - this.topDocsSupplier = new CachedSupplier<>(topDocsCollector::topDocs); - TotalHitCountCollector hitCountCollector = null; - if (trackTotalHits) { - // implicit total hit counts are valid only when there is no filter collector in the chain - int count = hasFilterCollector ? -1 : shortcutTotalHitCount(reader, query); - if (count != -1) { - // we can extract the total count from the shard statistics directly - this.totalHitsSupplier = () -> new TotalHits(count, TotalHits.Relation.EQUAL_TO); - } else { - // wrap a collector that counts the total number of hits even - // if the top docs collector terminates early - final TotalHitCountCollector countingCollector = new TotalHitCountCollector(); - hitCountCollector = countingCollector; - this.totalHitsSupplier = () -> new TotalHits(countingCollector.getTotalHits(), TotalHits.Relation.EQUAL_TO); - } - } else { - // total hit count is not needed - // for bwc hit count is set to 0, it will be converted to -1 by the coordinating node - this.totalHitsSupplier = () -> new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO); - } - MaxScoreCollector maxScoreCollector = null; - if (trackMaxScore) { - maxScoreCollector = new MaxScoreCollector(); - maxScoreSupplier = maxScoreCollector::getMaxScore; - } else { - maxScoreSupplier = () -> Float.NaN; - } - collector = MultiCollector.wrap(topDocsCollector, hitCountCollector, maxScoreCollector); + maxScoreSupplier = () -> Float.NaN; } + this.collector = MultiCollector.wrap(topDocsCollector, maxScoreCollector); } @Override