diff --git a/src/main/java/org/springframework/data/elasticsearch/core/StreamQueries.java b/src/main/java/org/springframework/data/elasticsearch/core/StreamQueries.java index da1d94a6c..3640ca700 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/StreamQueries.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/StreamQueries.java @@ -66,10 +66,14 @@ abstract class StreamQueries { private volatile Iterator> currentScrollHits = searchHits.iterator(); private volatile boolean continueScroll = currentScrollHits.hasNext(); private volatile ScrollState scrollState = new ScrollState(searchHits.getScrollId()); + private volatile boolean isClosed = false; @Override public void close() { - clearScrollConsumer.accept(scrollState.getScrollIds()); + if (!isClosed) { + clearScrollConsumer.accept(scrollState.getScrollIds()); + isClosed = true; + } } @Override @@ -96,18 +100,24 @@ abstract class StreamQueries { @Override public boolean hasNext() { - if (!continueScroll || (maxCount > 0 && currentCount.get() >= maxCount)) { - return false; + boolean hasNext = false; + + if (!isClosed && continueScroll && (maxCount <= 0 || currentCount.get() < maxCount)) { + + if (!currentScrollHits.hasNext()) { + SearchScrollHits nextPage = continueScrollFunction.apply(scrollState.getScrollId()); + currentScrollHits = nextPage.iterator(); + scrollState.updateScrollId(nextPage.getScrollId()); + continueScroll = currentScrollHits.hasNext(); + } + hasNext = currentScrollHits.hasNext(); } - if (!currentScrollHits.hasNext()) { - SearchScrollHits nextPage = continueScrollFunction.apply(scrollState.getScrollId()); - currentScrollHits = nextPage.iterator(); - scrollState.updateScrollId(nextPage.getScrollId()); - continueScroll = currentScrollHits.hasNext(); + if (!hasNext) { + close(); } - return currentScrollHits.hasNext(); + return hasNext; } @Override diff --git a/src/test/java/org/springframework/data/elasticsearch/core/StreamQueriesTest.java b/src/test/java/org/springframework/data/elasticsearch/core/StreamQueriesTest.java index 580202780..518d0c61f 100644 --- a/src/test/java/org/springframework/data/elasticsearch/core/StreamQueriesTest.java +++ b/src/test/java/org/springframework/data/elasticsearch/core/StreamQueriesTest.java @@ -24,6 +24,7 @@ import java.util.Iterator; import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; +import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import org.springframework.data.util.StreamUtils; @@ -39,6 +40,8 @@ public class StreamQueriesTest { // given List> hits = new ArrayList<>(); hits.add(getOneSearchHit()); + hits.add(getOneSearchHit()); + hits.add(getOneSearchHit()); SearchScrollHits searchHits = newSearchScrollHits(hits, "1234"); @@ -51,9 +54,7 @@ public class StreamQueriesTest { scrollId -> newSearchScrollHits(Collections.emptyList(), scrollId), // scrollIds -> clearScrollCalled.set(true)); - while (iterator.hasNext()) { - iterator.next(); - } + iterator.next(); iterator.close(); // then @@ -61,6 +62,27 @@ public class StreamQueriesTest { } + @Test // #1745 + @DisplayName("should call clearScroll when no more data is available") + void shouldCallClearScrollWhenNoMoreDataIsAvailable() { + + List> hits = new ArrayList<>(); + hits.add(getOneSearchHit()); + SearchScrollHits searchHits = newSearchScrollHits(hits, "1234"); + AtomicBoolean clearScrollCalled = new AtomicBoolean(false); + + SearchHitsIterator iterator = StreamQueries.streamResults( // + 0, // + searchHits, // + scrollId -> newSearchScrollHits(Collections.emptyList(), scrollId), // + scrollIds -> clearScrollCalled.set(true)); + + while (iterator.hasNext()) { + iterator.next(); + } + + assertThat(clearScrollCalled).isTrue(); + } private SearchHit getOneSearchHit() { return new SearchHit(null, null, null, 0, null, null, null, null, null, null, "one"); }