diff --git a/lucene/core/src/java/org/apache/lucene/search/TimeLimitingKnnCollectorManager.java b/lucene/core/src/java/org/apache/lucene/search/TimeLimitingKnnCollectorManager.java index c92fe0f9e34..aa26a72808c 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TimeLimitingKnnCollectorManager.java +++ b/lucene/core/src/java/org/apache/lucene/search/TimeLimitingKnnCollectorManager.java @@ -42,54 +42,62 @@ public class TimeLimitingKnnCollectorManager implements KnnCollectorManager { if (queryTimeout == null) { return collector; } - return new KnnCollector() { - @Override - public boolean earlyTerminated() { - return queryTimeout.shouldExit() || collector.earlyTerminated(); - } + return new TimeLimitingKnnCollector(collector); + } - @Override - public void incVisitedCount(int count) { - collector.incVisitedCount(count); - } + class TimeLimitingKnnCollector implements KnnCollector { + private final KnnCollector collector; - @Override - public long visitedCount() { - return collector.visitedCount(); - } + TimeLimitingKnnCollector(KnnCollector collector) { + this.collector = collector; + } - @Override - public long visitLimit() { - return collector.visitLimit(); - } + @Override + public boolean earlyTerminated() { + return queryTimeout.shouldExit() || collector.earlyTerminated(); + } - @Override - public int k() { - return collector.k(); - } + @Override + public void incVisitedCount(int count) { + collector.incVisitedCount(count); + } - @Override - public boolean collect(int docId, float similarity) { - return collector.collect(docId, similarity); - } + @Override + public long visitedCount() { + return collector.visitedCount(); + } - @Override - public float minCompetitiveSimilarity() { - return collector.minCompetitiveSimilarity(); - } + @Override + public long visitLimit() { + return collector.visitLimit(); + } - @Override - public TopDocs topDocs() { - TopDocs docs = collector.topDocs(); + @Override + public int k() { + return collector.k(); + } - // Mark results as partial if timeout is met - TotalHits.Relation relation = - queryTimeout.shouldExit() - ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO - : docs.totalHits.relation; + @Override + public boolean collect(int docId, float similarity) { + return collector.collect(docId, similarity); + } - return new TopDocs(new TotalHits(docs.totalHits.value, relation), docs.scoreDocs); - } - }; + @Override + public float minCompetitiveSimilarity() { + return collector.minCompetitiveSimilarity(); + } + + @Override + public TopDocs topDocs() { + TopDocs docs = collector.topDocs(); + + // Mark results as partial if timeout is met + TotalHits.Relation relation = + queryTimeout.shouldExit() + ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO + : docs.totalHits.relation; + + return new TopDocs(new TotalHits(docs.totalHits.value, relation), docs.scoreDocs); + } } } diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java index a37bb4a4dc0..be1526503ff 100644 --- a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java @@ -785,7 +785,8 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase { noTimeoutManager.newCollector(Integer.MAX_VALUE, searcher.leafContexts.get(0)); // Check that a normal collector is created without timeout - assertTrue(noTimeoutCollector instanceof TopKnnCollector); + assertFalse( + noTimeoutCollector instanceof TimeLimitingKnnCollectorManager.TimeLimitingKnnCollector); noTimeoutCollector.collect(0, 0); assertFalse(noTimeoutCollector.earlyTerminated());