diff --git a/src/main/java/org/springframework/data/elasticsearch/client/reactive/DefaultReactiveElasticsearchClient.java b/src/main/java/org/springframework/data/elasticsearch/client/reactive/DefaultReactiveElasticsearchClient.java index 7e5aff6d4..be5fd5c82 100644 --- a/src/main/java/org/springframework/data/elasticsearch/client/reactive/DefaultReactiveElasticsearchClient.java +++ b/src/main/java/org/springframework/data/elasticsearch/client/reactive/DefaultReactiveElasticsearchClient.java @@ -36,10 +36,7 @@ import java.net.ConnectException; import java.net.InetSocketAddress; import java.nio.charset.StandardCharsets; import java.time.Duration; -import java.util.ArrayList; import java.util.Collection; -import java.util.Collections; -import java.util.List; import java.util.Map.Entry; import java.util.Optional; import java.util.concurrent.TimeUnit; @@ -93,7 +90,6 @@ import org.elasticsearch.index.reindex.BulkByScrollResponse; import org.elasticsearch.index.reindex.DeleteByQueryRequest; import org.elasticsearch.rest.BytesRestResponse; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.search.Scroll; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.aggregations.Aggregation; @@ -105,6 +101,7 @@ import org.springframework.data.elasticsearch.client.NoReachableHostException; import org.springframework.data.elasticsearch.client.reactive.HostProvider.Verification; import org.springframework.data.elasticsearch.client.reactive.ReactiveElasticsearchClient.Indices; import org.springframework.data.elasticsearch.client.util.NamedXContents; +import org.springframework.data.elasticsearch.client.util.ScrollState; import org.springframework.data.util.Lazy; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; @@ -115,7 +112,6 @@ import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.ObjectUtils; import org.springframework.util.ReflectionUtils; -import org.springframework.util.StringUtils; import org.springframework.web.client.HttpClientErrorException; import org.springframework.web.client.HttpServerErrorException; import org.springframework.web.reactive.function.BodyExtractors; @@ -835,8 +831,7 @@ public class DefaultReactiveElasticsearchClient implements ReactiveElasticsearch .error(new ElasticsearchStatusException(content, RestStatus.fromCode(response.statusCode().value()))); } return Mono.just(content); - }) - .doOnNext(it -> ClientLogger.logResponse(logId, response.statusCode(), it)) // + }).doOnNext(it -> ClientLogger.logResponse(logId, response.statusCode(), it)) // .flatMap(content -> doDecode(response, responseType, content)); } @@ -893,42 +888,5 @@ public class DefaultReactiveElasticsearchClient implements ReactiveElasticsearch } } - /** - * Mutable state object holding scrollId to be used for {@link SearchScrollRequest#scroll(Scroll)} - * - * @author Christoph Strobl - * @since 3.2 - */ - private static class ScrollState { - - private final Object lock = new Object(); - - private final List pastIds = new ArrayList<>(1); - @Nullable private String scrollId; - - @Nullable - String getScrollId() { - return scrollId; - } - - List getScrollIds() { - - synchronized (lock) { - return Collections.unmodifiableList(new ArrayList<>(pastIds)); - } - } - - void updateScrollId(String scrollId) { - - if (StringUtils.hasText(scrollId)) { - - synchronized (lock) { - - this.scrollId = scrollId; - pastIds.add(scrollId); - } - } - } - } // endregion } diff --git a/src/main/java/org/springframework/data/elasticsearch/client/util/ScrollState.java b/src/main/java/org/springframework/data/elasticsearch/client/util/ScrollState.java new file mode 100644 index 000000000..1b3ad1b6e --- /dev/null +++ b/src/main/java/org/springframework/data/elasticsearch/client/util/ScrollState.java @@ -0,0 +1,72 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.elasticsearch.client.util; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Set; + +import org.elasticsearch.action.search.SearchScrollRequest; +import org.elasticsearch.search.Scroll; +import org.springframework.lang.Nullable; +import org.springframework.util.StringUtils; + +/** + * Mutable state object holding scrollId to be used for {@link SearchScrollRequest#scroll(Scroll)} + * + * @author Christoph Strobl + * @author Peter-Josef Meisch + * @since 3.2 + */ +public class ScrollState { + + private final Object lock = new Object(); + + private final Set pastIds = new LinkedHashSet<>(); + @Nullable private String scrollId; + + public ScrollState() {} + + public ScrollState(String scrollId) { + updateScrollId(scrollId); + } + + @Nullable + public String getScrollId() { + return scrollId; + } + + public List getScrollIds() { + + synchronized (lock) { + return Collections.unmodifiableList(new ArrayList<>(pastIds)); + } + } + + public void updateScrollId(String scrollId) { + + if (StringUtils.hasText(scrollId)) { + + synchronized (lock) { + + this.scrollId = scrollId; + pastIds.add(scrollId); + } + } + } +} diff --git a/src/main/java/org/springframework/data/elasticsearch/core/AbstractElasticsearchTemplate.java b/src/main/java/org/springframework/data/elasticsearch/core/AbstractElasticsearchTemplate.java index 2cc8633e6..8f54e6c83 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/AbstractElasticsearchTemplate.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/AbstractElasticsearchTemplate.java @@ -17,6 +17,7 @@ package org.springframework.data.elasticsearch.core; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.Iterator; import java.util.List; @@ -341,7 +342,14 @@ public abstract class AbstractElasticsearchTemplate implements ElasticsearchOper /* * internal use only, not for public API */ - abstract protected void searchScrollClear(String scrollId); + protected void searchScrollClear(String scrollId) { + searchScrollClear(Collections.singletonList(scrollId)); + } + + /* + * internal use only, not for public API + */ + abstract protected void searchScrollClear(List scrollIds); abstract protected MultiSearchResponse.Item[] getMultiSearchResult(MultiSearchRequest request); // endregion diff --git a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplate.java b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplate.java index 3d6fc25c5..0e4347a42 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplate.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplate.java @@ -299,9 +299,9 @@ public class ElasticsearchRestTemplate extends AbstractElasticsearchTemplate { } @Override - public void searchScrollClear(String scrollId) { + public void searchScrollClear(List scrollIds) { ClearScrollRequest request = new ClearScrollRequest(); - request.addScrollId(scrollId); + request.scrollIds(scrollIds); execute(client -> client.clearScroll(request, RequestOptions.DEFAULT)); } diff --git a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java index 066fc8bcf..1a58eed4b 100755 --- a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java @@ -321,8 +321,8 @@ public class ElasticsearchTemplate extends AbstractElasticsearchTemplate { } @Override - public void searchScrollClear(String scrollId) { - client.prepareClearScroll().addScrollId(scrollId).execute().actionGet(); + public void searchScrollClear(List scrollIds) { + client.prepareClearScroll().setScrollIds(scrollIds).execute().actionGet(); } @Override diff --git a/src/main/java/org/springframework/data/elasticsearch/core/StreamQueries.java b/src/main/java/org/springframework/data/elasticsearch/core/StreamQueries.java index 77573c8f1..866eefdc4 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/StreamQueries.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/StreamQueries.java @@ -16,11 +16,13 @@ package org.springframework.data.elasticsearch.core; import java.util.Iterator; +import java.util.List; import java.util.NoSuchElementException; import java.util.function.Consumer; import java.util.function.Function; import org.elasticsearch.search.aggregations.Aggregations; +import org.springframework.data.elasticsearch.client.util.ScrollState; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -38,12 +40,12 @@ abstract class StreamQueries { * * @param searchHits the initial hits * @param continueScrollFunction function to continue scrolling applies to the current scrollId. - * @param clearScrollConsumer consumer to clear the scroll context by accepting the current scrollId. + * @param clearScrollConsumer consumer to clear the scroll context by accepting the scrollIds to clear. * @param * @return the {@link SearchHitsIterator}. */ static SearchHitsIterator streamResults(SearchScrollHits searchHits, - Function> continueScrollFunction, Consumer clearScrollConsumer) { + Function> continueScrollFunction, Consumer> clearScrollConsumer) { Assert.notNull(searchHits, "searchHits must not be null."); Assert.notNull(searchHits.getScrollId(), "scrollId of searchHits must not be null."); @@ -59,17 +61,17 @@ abstract class StreamQueries { // As we couldn't retrieve single result with scroll, store current hits. private volatile Iterator> scrollHits = searchHits.iterator(); - private volatile String scrollId = searchHits.getScrollId(); private volatile boolean continueScroll = scrollHits.hasNext(); + private volatile ScrollState scrollState = new ScrollState(searchHits.getScrollId()); @Override public void close() { try { - clearScrollConsumer.accept(scrollId); + clearScrollConsumer.accept(scrollState.getScrollIds()); } finally { scrollHits = null; - scrollId = null; + scrollState = null; } } @@ -102,9 +104,9 @@ abstract class StreamQueries { } if (!scrollHits.hasNext()) { - SearchScrollHits nextPage = continueScrollFunction.apply(scrollId); + SearchScrollHits nextPage = continueScrollFunction.apply(scrollState.getScrollId()); scrollHits = nextPage.iterator(); - scrollId = nextPage.getScrollId(); + scrollState.updateScrollId(nextPage.getScrollId()); continueScroll = scrollHits.hasNext(); } @@ -127,6 +129,5 @@ abstract class StreamQueries { } // utility constructor - private StreamQueries() { - } + private StreamQueries() {} } diff --git a/src/test/java/org/springframework/data/elasticsearch/client/util/ScrollStateTest.java b/src/test/java/org/springframework/data/elasticsearch/client/util/ScrollStateTest.java new file mode 100644 index 000000000..d93d6e75a --- /dev/null +++ b/src/test/java/org/springframework/data/elasticsearch/client/util/ScrollStateTest.java @@ -0,0 +1,52 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.elasticsearch.client.util; + +import static org.assertj.core.api.Assertions.*; + +import java.util.Arrays; + +import org.junit.jupiter.api.Test; + +/** + * @author Peter-Josef Meisch + */ +class ScrollStateTest { + + @Test // DATAES-817 + void shouldReturnLastSetScrollId() { + ScrollState scrollState = new ScrollState(); + + scrollState.updateScrollId("id-1"); + scrollState.updateScrollId("id-2"); + + assertThat(scrollState.getScrollId()).isEqualTo("id-2"); + } + + @Test + void shouldReturnUniqueListOfUsedScrollIdsInCorrectOrder() { + + ScrollState scrollState = new ScrollState(); + + scrollState.updateScrollId("id-1"); + scrollState.updateScrollId("id-2"); + scrollState.updateScrollId("id-1"); + scrollState.updateScrollId("id-3"); + scrollState.updateScrollId("id-2"); + + assertThat(scrollState.getScrollIds()).isEqualTo(Arrays.asList("id-1", "id-2", "id-3")); + } +} diff --git a/src/test/java/org/springframework/data/elasticsearch/core/StreamQueriesTest.java b/src/test/java/org/springframework/data/elasticsearch/core/StreamQueriesTest.java index 5a481c510..f55c18309 100644 --- a/src/test/java/org/springframework/data/elasticsearch/core/StreamQueriesTest.java +++ b/src/test/java/org/springframework/data/elasticsearch/core/StreamQueriesTest.java @@ -18,19 +18,17 @@ package org.springframework.data.elasticsearch.core; import static org.assertj.core.api.Assertions.*; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; -import org.elasticsearch.search.aggregations.Aggregations; import org.junit.jupiter.api.Test; -import org.springframework.data.domain.PageImpl; -import org.springframework.data.domain.Pageable; -import org.springframework.lang.Nullable; /** * @author Sascha Woo + * @author Peter-Josef Meisch */ public class StreamQueriesTest { @@ -41,15 +39,15 @@ public class StreamQueriesTest { List> hits = new ArrayList<>(); hits.add(new SearchHit(null, 0, null, null, "one")); - SearchScrollHits searchHits = newSearchScrollHits(hits); + SearchScrollHits searchHits = newSearchScrollHits(hits, "1234"); AtomicBoolean clearScrollCalled = new AtomicBoolean(false); // when SearchHitsIterator iterator = StreamQueries.streamResults( // searchHits, // - scrollId -> newSearchScrollHits(Collections.emptyList()), // - scrollId -> clearScrollCalled.set(true)); + scrollId -> newSearchScrollHits(Collections.emptyList(), scrollId), // + scrollIds -> clearScrollCalled.set(true)); while (iterator.hasNext()) { iterator.next(); @@ -68,21 +66,47 @@ public class StreamQueriesTest { List> hits = new ArrayList<>(); hits.add(new SearchHit(null, 0, null, null, "one")); - SearchScrollHits searchHits = newSearchScrollHits(hits); + SearchScrollHits searchHits = newSearchScrollHits(hits, "1234"); // when SearchHitsIterator iterator = StreamQueries.streamResults( // searchHits, // - scrollId -> newSearchScrollHits(Collections.emptyList()), // - scrollId -> { - }); + scrollId -> newSearchScrollHits(Collections.emptyList(), scrollId), // + scrollId -> {}); // then assertThat(iterator.getTotalHits()).isEqualTo(1); } - private SearchScrollHits newSearchScrollHits(List> hits) { - return new SearchHitsImpl(hits.size(), TotalHitsRelation.EQUAL_TO, 0, "1234", hits, null); + @Test // DATAES-817 + void shouldClearAllScrollIds() { + + SearchScrollHits searchHits1 = newSearchScrollHits( + Collections.singletonList(new SearchHit(null, 0, null, null, "one")), "s-1"); + SearchScrollHits searchHits2 = newSearchScrollHits( + Collections.singletonList(new SearchHit(null, 0, null, null, "one")), "s-2"); + SearchScrollHits searchHits3 = newSearchScrollHits( + Collections.singletonList(new SearchHit(null, 0, null, null, "one")), "s-2"); + SearchScrollHits searchHits4 = newSearchScrollHits(Collections.emptyList(), "s-3"); + + Iterator> searchScrollHitsIterator = Arrays.asList(searchHits1, searchHits2, searchHits3,searchHits4).iterator(); + + List clearedScrollIds = new ArrayList<>(); + SearchHitsIterator iterator = StreamQueries.streamResults( // + searchScrollHitsIterator.next(), // + scrollId -> searchScrollHitsIterator.next(), // + scrollIds -> clearedScrollIds.addAll(scrollIds)); + + while (iterator.hasNext()) { + iterator.next(); + } + iterator.close(); + + assertThat(clearedScrollIds).isEqualTo(Arrays.asList("s-1", "s-2", "s-3")); + } + + private SearchScrollHits newSearchScrollHits(List> hits, String scrollId) { + return new SearchHitsImpl(hits.size(), TotalHitsRelation.EQUAL_TO, 0, scrollId, hits, null); } }