diff --git a/lucene/core/src/java/org/apache/lucene/search/HitsThresholdChecker.java b/lucene/core/src/java/org/apache/lucene/search/HitsThresholdChecker.java index 9e42cd7bfef..fb6f6bf4343 100644 --- a/lucene/core/src/java/org/apache/lucene/search/HitsThresholdChecker.java +++ b/lucene/core/src/java/org/apache/lucene/search/HitsThresholdChecker.java @@ -23,96 +23,112 @@ import java.util.concurrent.atomic.AtomicLong; abstract class HitsThresholdChecker { /** Implementation of HitsThresholdChecker which allows global hit counting */ private static class GlobalHitsThresholdChecker extends HitsThresholdChecker { - private final int totalHitsThreshold; - private final AtomicLong globalHitCount; + private final AtomicLong globalHitCount = new AtomicLong(); - public GlobalHitsThresholdChecker(int totalHitsThreshold) { - - if (totalHitsThreshold < 0) { - throw new IllegalArgumentException( - "totalHitsThreshold must be >= 0, got " + totalHitsThreshold); - } - - this.totalHitsThreshold = totalHitsThreshold; - this.globalHitCount = new AtomicLong(); + GlobalHitsThresholdChecker(int totalHitsThreshold) { + super(totalHitsThreshold); + assert totalHitsThreshold != Integer.MAX_VALUE; } @Override - public void incrementHitCount() { + void incrementHitCount() { globalHitCount.incrementAndGet(); } @Override - public boolean isThresholdReached() { - return globalHitCount.getAcquire() > totalHitsThreshold; + boolean isThresholdReached() { + return globalHitCount.getAcquire() > getHitsThreshold(); } @Override - public ScoreMode scoreMode() { - return totalHitsThreshold == Integer.MAX_VALUE ? ScoreMode.COMPLETE : ScoreMode.TOP_SCORES; - } - - @Override - public int getHitsThreshold() { - return totalHitsThreshold; + ScoreMode scoreMode() { + return ScoreMode.TOP_SCORES; } } /** Default implementation of HitsThresholdChecker to be used for single threaded execution */ private static class LocalHitsThresholdChecker extends HitsThresholdChecker { - private final int totalHitsThreshold; private int hitCount; - public LocalHitsThresholdChecker(int totalHitsThreshold) { - - if (totalHitsThreshold < 0) { - throw new IllegalArgumentException( - "totalHitsThreshold must be >= 0, got " + totalHitsThreshold); - } - - this.totalHitsThreshold = totalHitsThreshold; + LocalHitsThresholdChecker(int totalHitsThreshold) { + super(totalHitsThreshold); + assert totalHitsThreshold != Integer.MAX_VALUE; } @Override - public void incrementHitCount() { + void incrementHitCount() { ++hitCount; } @Override - public boolean isThresholdReached() { - return hitCount > totalHitsThreshold; + boolean isThresholdReached() { + return hitCount > getHitsThreshold(); } @Override - public ScoreMode scoreMode() { - return totalHitsThreshold == Integer.MAX_VALUE ? ScoreMode.COMPLETE : ScoreMode.TOP_SCORES; - } - - @Override - public int getHitsThreshold() { - return totalHitsThreshold; + ScoreMode scoreMode() { + return ScoreMode.TOP_SCORES; } } + /** + * No-op implementation of {@link HitsThresholdChecker} that does no counting, as the threshold + * can never be reached. This is useful for cases where early termination is never desired, so + * that the overhead of counting hits can be avoided. + */ + private static final HitsThresholdChecker EXACT_HITS_COUNT_THRESHOLD_CHECKER = + new HitsThresholdChecker(Integer.MAX_VALUE) { + @Override + void incrementHitCount() { + // noop + } + + @Override + boolean isThresholdReached() { + return false; + } + + @Override + ScoreMode scoreMode() { + return ScoreMode.COMPLETE; + } + }; + /* * Returns a threshold checker that is useful for single threaded searches */ - public static HitsThresholdChecker create(final int totalHitsThreshold) { - return new LocalHitsThresholdChecker(totalHitsThreshold); + static HitsThresholdChecker create(final int totalHitsThreshold) { + return totalHitsThreshold == Integer.MAX_VALUE + ? HitsThresholdChecker.EXACT_HITS_COUNT_THRESHOLD_CHECKER + : new LocalHitsThresholdChecker(totalHitsThreshold); } /* * Returns a threshold checker that is based on a shared counter */ - public static HitsThresholdChecker createShared(final int totalHitsThreshold) { - return new GlobalHitsThresholdChecker(totalHitsThreshold); + static HitsThresholdChecker createShared(final int totalHitsThreshold) { + return totalHitsThreshold == Integer.MAX_VALUE + ? HitsThresholdChecker.EXACT_HITS_COUNT_THRESHOLD_CHECKER + : new GlobalHitsThresholdChecker(totalHitsThreshold); } - public abstract void incrementHitCount(); + private final int totalHitsThreshold; - public abstract ScoreMode scoreMode(); + HitsThresholdChecker(int totalHitsThreshold) { + if (totalHitsThreshold < 0) { + throw new IllegalArgumentException( + "totalHitsThreshold must be >= 0, got " + totalHitsThreshold); + } + this.totalHitsThreshold = totalHitsThreshold; + } - public abstract int getHitsThreshold(); + final int getHitsThreshold() { + return totalHitsThreshold; + } - public abstract boolean isThresholdReached(); + abstract boolean isThresholdReached(); + + abstract ScoreMode scoreMode(); + + abstract void incrementHitCount(); } diff --git a/lucene/core/src/java/org/apache/lucene/search/TopFieldCollector.java b/lucene/core/src/java/org/apache/lucene/search/TopFieldCollector.java index 0023d718424..2ecb5e08694 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TopFieldCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/TopFieldCollector.java @@ -488,11 +488,13 @@ public abstract class TopFieldCollector extends TopDocsCollector { */ public static CollectorManager createSharedManager( Sort sort, int numHits, FieldDoc after, int totalHitsThreshold) { - return new CollectorManager<>() { + int totalHitsMax = Math.max(totalHitsThreshold, numHits); + return new CollectorManager<>() { private final HitsThresholdChecker hitsThresholdChecker = - HitsThresholdChecker.createShared(Math.max(totalHitsThreshold, numHits)); - private final MaxScoreAccumulator minScoreAcc = new MaxScoreAccumulator(); + HitsThresholdChecker.createShared(totalHitsMax); + private final MaxScoreAccumulator minScoreAcc = + totalHitsMax == Integer.MAX_VALUE ? null : new MaxScoreAccumulator(); @Override public TopFieldCollector newCollector() throws IOException { diff --git a/lucene/core/src/java/org/apache/lucene/search/TopScoreDocCollector.java b/lucene/core/src/java/org/apache/lucene/search/TopScoreDocCollector.java index 0c4ccdcda7d..a6de5ac3c13 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TopScoreDocCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/TopScoreDocCollector.java @@ -255,11 +255,14 @@ public abstract class TopScoreDocCollector extends TopDocsCollector { */ public static CollectorManager createSharedManager( int numHits, ScoreDoc after, int totalHitsThreshold) { + + int totalHitsMax = Math.max(totalHitsThreshold, numHits); return new CollectorManager<>() { private final HitsThresholdChecker hitsThresholdChecker = - HitsThresholdChecker.createShared(Math.max(totalHitsThreshold, numHits)); - private final MaxScoreAccumulator minScoreAcc = new MaxScoreAccumulator(); + HitsThresholdChecker.createShared(totalHitsMax); + private final MaxScoreAccumulator minScoreAcc = + totalHitsMax == Integer.MAX_VALUE ? null : new MaxScoreAccumulator(); @Override public TopScoreDocCollector newCollector() throws IOException { diff --git a/lucene/facet/src/java/org/apache/lucene/facet/DrillSideways.java b/lucene/facet/src/java/org/apache/lucene/facet/DrillSideways.java index 5a09720adda..372ac4ade0c 100644 --- a/lucene/facet/src/java/org/apache/lucene/facet/DrillSideways.java +++ b/lucene/facet/src/java/org/apache/lucene/facet/DrillSideways.java @@ -18,7 +18,6 @@ package org.apache.lucene.facet; import java.io.IOException; import java.util.ArrayList; -import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -198,23 +197,7 @@ public class DrillSideways { final int fTopN = Math.min(topN, limit); final CollectorManager collectorManager = - new CollectorManager<>() { - - @Override - public TopFieldCollector newCollector() { - return TopFieldCollector.create(sort, fTopN, after, Integer.MAX_VALUE); - } - - @Override - public TopFieldDocs reduce(Collection collectors) { - final TopFieldDocs[] topFieldDocs = new TopFieldDocs[collectors.size()]; - int pos = 0; - for (TopFieldCollector collector : collectors) - topFieldDocs[pos++] = collector.topDocs(); - return TopDocs.merge(sort, topN, topFieldDocs); - } - }; - + TopFieldCollector.createSharedManager(sort, fTopN, after, Integer.MAX_VALUE); final ConcurrentDrillSidewaysResult r = search(query, collectorManager); TopFieldDocs topDocs = r.collectorResult; @@ -247,24 +230,7 @@ public class DrillSideways { final int fTopN = Math.min(topN, limit); final CollectorManager collectorManager = - new CollectorManager<>() { - - @Override - public TopScoreDocCollector newCollector() { - return TopScoreDocCollector.create(fTopN, after, Integer.MAX_VALUE); - } - - @Override - public TopDocs reduce(Collection collectors) { - final TopDocs[] topDocs = new TopDocs[collectors.size()]; - int pos = 0; - for (TopScoreDocCollector collector : collectors) { - topDocs[pos++] = collector.topDocs(); - } - return TopDocs.merge(topN, topDocs); - } - }; - + TopScoreDocCollector.createSharedManager(fTopN, after, Integer.MAX_VALUE); final ConcurrentDrillSidewaysResult r = search(query, collectorManager); return new DrillSidewaysResult( r.facets,