diff --git a/src/main/java/org/springframework/data/elasticsearch/core/AbstractElasticsearchTemplate.java b/src/main/java/org/springframework/data/elasticsearch/core/AbstractElasticsearchTemplate.java index 8f54e6c83..3e809613b 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/AbstractElasticsearchTemplate.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/AbstractElasticsearchTemplate.java @@ -258,7 +258,11 @@ public abstract class AbstractElasticsearchTemplate implements ElasticsearchOper long scrollTimeInMillis = TimeValue.timeValueMinutes(1).millis(); + // noinspection ConstantConditions + int maxCount = query.isLimiting() ? query.getMaxResults() : 0; + return StreamQueries.streamResults( // + maxCount, // searchScrollStart(scrollTimeInMillis, query, clazz, index), // scrollId -> searchScrollContinue(scrollId, scrollTimeInMillis, clazz, index), // this::searchScrollClear); diff --git a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplate.java b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplate.java index 0e4347a42..8a1eeaa71 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplate.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplate.java @@ -40,6 +40,8 @@ import org.elasticsearch.index.reindex.DeleteByQueryRequest; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.search.suggest.SuggestBuilder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.springframework.data.elasticsearch.core.convert.ElasticsearchConverter; import org.springframework.data.elasticsearch.core.document.DocumentAdapters; import org.springframework.data.elasticsearch.core.document.SearchDocumentResponse; @@ -88,6 +90,8 @@ import org.springframework.util.Assert; */ public class ElasticsearchRestTemplate extends AbstractElasticsearchTemplate { + private static final Logger LOGGER = LoggerFactory.getLogger(ElasticsearchRestTemplate.class); + private RestHighLevelClient client; private ElasticsearchExceptionTranslator exceptionTranslator; @@ -300,9 +304,13 @@ public class ElasticsearchRestTemplate extends AbstractElasticsearchTemplate { @Override public void searchScrollClear(List scrollIds) { - ClearScrollRequest request = new ClearScrollRequest(); - request.scrollIds(scrollIds); - execute(client -> client.clearScroll(request, RequestOptions.DEFAULT)); + try { + ClearScrollRequest request = new ClearScrollRequest(); + request.scrollIds(scrollIds); + execute(client -> client.clearScroll(request, RequestOptions.DEFAULT)); + } catch (Exception e) { + LOGGER.warn("Could not clear scroll: {}", e.getMessage()); + } } @Override diff --git a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java index 1a58eed4b..c09827f79 100755 --- a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java @@ -86,6 +86,7 @@ import org.springframework.util.Assert; public class ElasticsearchTemplate extends AbstractElasticsearchTemplate { private static final Logger QUERY_LOGGER = LoggerFactory .getLogger("org.springframework.data.elasticsearch.core.QUERY"); + private static final Logger LOGGER = LoggerFactory.getLogger(ElasticsearchTemplate.class); private Client client; @Nullable private String searchTimeout; @@ -322,7 +323,11 @@ public class ElasticsearchTemplate extends AbstractElasticsearchTemplate { @Override public void searchScrollClear(List scrollIds) { - client.prepareClearScroll().setScrollIds(scrollIds).execute().actionGet(); + try { + client.prepareClearScroll().setScrollIds(scrollIds).execute().actionGet(); + } catch (Exception e) { + LOGGER.warn("Could not clear scroll: {}", e.getMessage()); + } } @Override diff --git a/src/main/java/org/springframework/data/elasticsearch/core/StreamQueries.java b/src/main/java/org/springframework/data/elasticsearch/core/StreamQueries.java index 866eefdc4..12f605755 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/StreamQueries.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/StreamQueries.java @@ -18,6 +18,7 @@ package org.springframework.data.elasticsearch.core; import java.util.Iterator; import java.util.List; import java.util.NoSuchElementException; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; import java.util.function.Function; @@ -38,13 +39,15 @@ abstract class StreamQueries { /** * Stream query results using {@link SearchScrollHits}. * + * @param maxCount the maximum number of entities to return, a value of 0 means that all available entities are + * returned * @param searchHits the initial hits * @param continueScrollFunction function to continue scrolling applies to the current scrollId. * @param clearScrollConsumer consumer to clear the scroll context by accepting the scrollIds to clear. - * @param + * @param the entity type * @return the {@link SearchHitsIterator}. */ - static SearchHitsIterator streamResults(SearchScrollHits searchHits, + static SearchHitsIterator streamResults(int maxCount, SearchScrollHits searchHits, Function> continueScrollFunction, Consumer> clearScrollConsumer) { Assert.notNull(searchHits, "searchHits must not be null."); @@ -59,20 +62,14 @@ abstract class StreamQueries { return new SearchHitsIterator() { - // As we couldn't retrieve single result with scroll, store current hits. - private volatile Iterator> scrollHits = searchHits.iterator(); - private volatile boolean continueScroll = scrollHits.hasNext(); + private volatile AtomicInteger currentCount = new AtomicInteger(); + private volatile Iterator> currentScrollHits = searchHits.iterator(); + private volatile boolean continueScroll = currentScrollHits.hasNext(); private volatile ScrollState scrollState = new ScrollState(searchHits.getScrollId()); @Override public void close() { - - try { - clearScrollConsumer.accept(scrollState.getScrollIds()); - } finally { - scrollHits = null; - scrollState = null; - } + clearScrollConsumer.accept(scrollState.getScrollIds()); } @Override @@ -99,24 +96,25 @@ abstract class StreamQueries { @Override public boolean hasNext() { - if (!continueScroll) { + if (!continueScroll || (maxCount > 0 && currentCount.get() >= maxCount)) { return false; } - if (!scrollHits.hasNext()) { + if (!currentScrollHits.hasNext()) { SearchScrollHits nextPage = continueScrollFunction.apply(scrollState.getScrollId()); - scrollHits = nextPage.iterator(); + currentScrollHits = nextPage.iterator(); scrollState.updateScrollId(nextPage.getScrollId()); - continueScroll = scrollHits.hasNext(); + continueScroll = currentScrollHits.hasNext(); } - return scrollHits.hasNext(); + return currentScrollHits.hasNext(); } @Override public SearchHit next() { if (hasNext()) { - return scrollHits.next(); + currentCount.incrementAndGet(); + return currentScrollHits.next(); } throw new NoSuchElementException(); } diff --git a/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java b/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java index e5eed540e..5d426a46b 100755 --- a/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java +++ b/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java @@ -77,7 +77,7 @@ import org.springframework.data.elasticsearch.annotations.ScriptedField; import org.springframework.data.elasticsearch.core.geo.GeoPoint; import org.springframework.data.elasticsearch.core.mapping.IndexCoordinates; import org.springframework.data.elasticsearch.core.query.*; -import org.springframework.data.util.CloseableIterator; +import org.springframework.data.util.StreamUtils; import org.springframework.lang.Nullable; /** @@ -1298,27 +1298,33 @@ public abstract class ElasticsearchTemplateTests { assertThat(sampleEntities).hasSize(30); } - @Test // DATAES-167 - public void shouldReturnResultsWithStreamForGivenCriteriaQuery() { + @Test // DATAES-167, DATAES-831 + public void shouldReturnAllResultsWithStreamForGivenCriteriaQuery() { - // given - List entities = createSampleEntitiesWithMessage("Test message", 30); - - // when - operations.bulkIndex(entities, index); + operations.bulkIndex(createSampleEntitiesWithMessage("Test message", 30), index); indexOperations.refresh(); - - // then CriteriaQuery criteriaQuery = new CriteriaQuery(new Criteria()); criteriaQuery.setPageable(PageRequest.of(0, 10)); - CloseableIterator> stream = operations.searchForStream(criteriaQuery, SampleEntity.class, - index); - List> sampleEntities = new ArrayList<>(); - while (stream.hasNext()) { - sampleEntities.add(stream.next()); - } - assertThat(sampleEntities).hasSize(30); + long count = StreamUtils + .createStreamFromIterator(operations.searchForStream(criteriaQuery, SampleEntity.class, index)).count(); + + assertThat(count).isEqualTo(30); + } + + @Test // DATAES-831 + void shouldLimitStreamResultToRequestedSize() { + + operations.bulkIndex(createSampleEntitiesWithMessage("Test message", 30), index); + indexOperations.refresh(); + + CriteriaQuery criteriaQuery = new CriteriaQuery(new Criteria()); + criteriaQuery.setMaxResults(10); + + long count = StreamUtils + .createStreamFromIterator(operations.searchForStream(criteriaQuery, SampleEntity.class, index)).count(); + + assertThat(count).isEqualTo(10); } private static List createSampleEntitiesWithMessage(String message, int numberOfEntities) { @@ -3128,8 +3134,8 @@ public abstract class ElasticsearchTemplateTests { operations.refresh(OptimisticEntity.class); List queries = singletonList(queryForOne(saved.getId())); - List> retrievedHits = operations.multiSearch(queries, - OptimisticEntity.class, operations.getIndexCoordinatesFor(OptimisticEntity.class)); + List> retrievedHits = operations.multiSearch(queries, OptimisticEntity.class, + operations.getIndexCoordinatesFor(OptimisticEntity.class)); OptimisticEntity retrieved = retrievedHits.get(0).getSearchHit(0).getContent(); assertThatSeqNoPrimaryTermIsFilled(retrieved); @@ -3162,8 +3168,7 @@ public abstract class ElasticsearchTemplateTests { operations.save(forEdit1); forEdit2.setMessage("It'll be great"); - assertThatThrownBy(() -> operations.save(forEdit2)) - .isInstanceOf(OptimisticLockingFailureException.class); + assertThatThrownBy(() -> operations.save(forEdit2)).isInstanceOf(OptimisticLockingFailureException.class); } @Test // DATAES-799 @@ -3179,8 +3184,7 @@ public abstract class ElasticsearchTemplateTests { operations.save(forEdit1); forEdit2.setMessage("It'll be great"); - assertThatThrownBy(() -> operations.save(forEdit2)) - .isInstanceOf(OptimisticLockingFailureException.class); + assertThatThrownBy(() -> operations.save(forEdit2)).isInstanceOf(OptimisticLockingFailureException.class); } @Test // DATAES-799 diff --git a/src/test/java/org/springframework/data/elasticsearch/core/StreamQueriesTest.java b/src/test/java/org/springframework/data/elasticsearch/core/StreamQueriesTest.java index f55c18309..51cc2589f 100644 --- a/src/test/java/org/springframework/data/elasticsearch/core/StreamQueriesTest.java +++ b/src/test/java/org/springframework/data/elasticsearch/core/StreamQueriesTest.java @@ -25,6 +25,7 @@ import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; import org.junit.jupiter.api.Test; +import org.springframework.data.util.StreamUtils; /** * @author Sascha Woo @@ -45,6 +46,7 @@ public class StreamQueriesTest { // when SearchHitsIterator iterator = StreamQueries.streamResults( // + 0, // searchHits, // scrollId -> newSearchScrollHits(Collections.emptyList(), scrollId), // scrollIds -> clearScrollCalled.set(true)); @@ -70,6 +72,7 @@ public class StreamQueriesTest { // when SearchHitsIterator iterator = StreamQueries.streamResults( // + 0, // searchHits, // scrollId -> newSearchScrollHits(Collections.emptyList(), scrollId), // scrollId -> {}); @@ -90,10 +93,12 @@ public class StreamQueriesTest { Collections.singletonList(new SearchHit(null, 0, null, null, "one")), "s-2"); SearchScrollHits searchHits4 = newSearchScrollHits(Collections.emptyList(), "s-3"); - Iterator> searchScrollHitsIterator = Arrays.asList(searchHits1, searchHits2, searchHits3,searchHits4).iterator(); + Iterator> searchScrollHitsIterator = Arrays + .asList(searchHits1, searchHits2, searchHits3, searchHits4).iterator(); List clearedScrollIds = new ArrayList<>(); SearchHitsIterator iterator = StreamQueries.streamResults( // + 0, // searchScrollHitsIterator.next(), // scrollId -> searchScrollHitsIterator.next(), // scrollIds -> clearedScrollIds.addAll(scrollIds)); @@ -106,6 +111,56 @@ public class StreamQueriesTest { assertThat(clearedScrollIds).isEqualTo(Arrays.asList("s-1", "s-2", "s-3")); } + @Test // DATAES-831 + void shouldReturnAllForRequestedSizeOf0() { + + SearchScrollHits searchHits1 = newSearchScrollHits( + Collections.singletonList(new SearchHit(null, 0, null, null, "one")), "s-1"); + SearchScrollHits searchHits2 = newSearchScrollHits( + Collections.singletonList(new SearchHit(null, 0, null, null, "one")), "s-2"); + SearchScrollHits searchHits3 = newSearchScrollHits( + Collections.singletonList(new SearchHit(null, 0, null, null, "one")), "s-2"); + SearchScrollHits searchHits4 = newSearchScrollHits(Collections.emptyList(), "s-3"); + + Iterator> searchScrollHitsIterator = Arrays + .asList(searchHits1, searchHits2, searchHits3, searchHits4).iterator(); + + SearchHitsIterator iterator = StreamQueries.streamResults( // + 0, // + searchScrollHitsIterator.next(), // + scrollId -> searchScrollHitsIterator.next(), // + scrollIds -> {}); + + long count = StreamUtils.createStreamFromIterator(iterator).count(); + + assertThat(count).isEqualTo(3); + } + + @Test // DATAES-831 + void shouldOnlyReturnRequestedCount() { + + SearchScrollHits searchHits1 = newSearchScrollHits( + Collections.singletonList(new SearchHit(null, 0, null, null, "one")), "s-1"); + SearchScrollHits searchHits2 = newSearchScrollHits( + Collections.singletonList(new SearchHit(null, 0, null, null, "one")), "s-2"); + SearchScrollHits searchHits3 = newSearchScrollHits( + Collections.singletonList(new SearchHit(null, 0, null, null, "one")), "s-2"); + SearchScrollHits searchHits4 = newSearchScrollHits(Collections.emptyList(), "s-3"); + + Iterator> searchScrollHitsIterator = Arrays + .asList(searchHits1, searchHits2, searchHits3, searchHits4).iterator(); + + SearchHitsIterator iterator = StreamQueries.streamResults( // + 2, // + searchScrollHitsIterator.next(), // + scrollId -> searchScrollHitsIterator.next(), // + scrollIds -> {}); + + long count = StreamUtils.createStreamFromIterator(iterator).count(); + + assertThat(count).isEqualTo(2); + } + private SearchScrollHits newSearchScrollHits(List> hits, String scrollId) { return new SearchHitsImpl(hits.size(), TotalHitsRelation.EQUAL_TO, 0, scrollId, hits, null); }