diff --git a/src/main/java/org/springframework/data/elasticsearch/client/reactive/DefaultReactiveElasticsearchClient.java b/src/main/java/org/springframework/data/elasticsearch/client/reactive/DefaultReactiveElasticsearchClient.java index 336088f8b..f588ac5d9 100644 --- a/src/main/java/org/springframework/data/elasticsearch/client/reactive/DefaultReactiveElasticsearchClient.java +++ b/src/main/java/org/springframework/data/elasticsearch/client/reactive/DefaultReactiveElasticsearchClient.java @@ -78,6 +78,7 @@ import org.elasticsearch.rest.BytesRestResponse; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.Scroll; import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.SearchHits; import org.reactivestreams.Publisher; import org.springframework.data.elasticsearch.ElasticsearchException; import org.springframework.data.elasticsearch.client.ClientConfiguration; @@ -92,6 +93,7 @@ import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.client.reactive.ReactorClientHttpConnector; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.ObjectUtils; import org.springframework.util.ReflectionUtils; @@ -339,41 +341,55 @@ public class DefaultReactiveElasticsearchClient implements ReactiveElasticsearch } throw new IllegalArgumentException( - String.format("Cannot handle '%s'. Please make sure to use a 'SearchRequest' or 'SearchScrollRequest'.")); + String.format("Cannot handle '%s'. Please make sure to use a 'SearchRequest' or 'SearchScrollRequest'.", it)); }); - ScrollState state = new ScrollState(); + return Flux.usingWhen(Mono.fromSupplier(ScrollState::new), - Flux searchHits = inbound.doOnNext(searchResponse -> { - state.updateScrollId(searchResponse.getScrollId()); - }). handle((searchResponse, sink) -> { + scrollState -> { - if (searchResponse.getHits() != null && searchResponse.getHits().getHits() != null - && searchResponse.getHits().getHits().length == 0) { + Flux searchHits = inbound. handle((searchResponse, sink) -> { - inbound.onComplete(); - outbound.onComplete(); + scrollState.updateScrollId(searchResponse.getScrollId()); + if (isEmpty(searchResponse.getHits())) { - } else { + inbound.onComplete(); + outbound.onComplete(); - sink.next(searchResponse); + } else { - SearchScrollRequest searchScrollRequest = new SearchScrollRequest(state.getScrollId()).scroll(scrollTimeout); - request.next(searchScrollRequest); - } + sink.next(searchResponse); - }).map(SearchResponse::getHits) // - .flatMap(Flux::fromIterable) // - .doOnComplete(() -> { + SearchScrollRequest searchScrollRequest = new SearchScrollRequest(scrollState.getScrollId()) + .scroll(scrollTimeout); + request.next(searchScrollRequest); + } - ClearScrollRequest clearScrollRequest = new ClearScrollRequest(); - clearScrollRequest.scrollIds(state.getScrollIds()); + }).map(SearchResponse::getHits) // + .flatMap(Flux::fromIterable); - // just send the request, resources get cleaned up anyways after scrollTimeout has been reached. - sendRequest(clearScrollRequest, RequestCreator.clearScroll(), ClearScrollResponse.class, headers).subscribe(); - }); + return searchHits.doOnSubscribe(ignore -> exchange.subscribe(inbound)); - return searchHits.doOnSubscribe(ignore -> exchange.subscribe(inbound)); + }, state -> cleanupScroll(headers, state), // + state -> cleanupScroll(headers, state), // + state -> cleanupScroll(headers, state)); // + } + + private static boolean isEmpty(@Nullable SearchHits hits) { + return hits != null && hits.getHits() != null && hits.getHits().length == 0; + } + + private Publisher cleanupScroll(HttpHeaders headers, ScrollState state) { + + if (state.getScrollIds().isEmpty()) { + return Mono.empty(); + } + + ClearScrollRequest clearScrollRequest = new ClearScrollRequest(); + clearScrollRequest.scrollIds(state.getScrollIds()); + + // just send the request, resources get cleaned up anyways after scrollTimeout has been reached. + return sendRequest(clearScrollRequest, RequestCreator.clearScroll(), ClearScrollResponse.class, headers); } /* @@ -645,17 +661,20 @@ public class DefaultReactiveElasticsearchClient implements ReactiveElasticsearch */ private static class ScrollState { - private Object lock = new Object(); + private final Object lock = new Object(); + private final List pastIds = new ArrayList<>(1); private String scrollId; - private List pastIds = new ArrayList<>(1); String getScrollId() { return scrollId; } List getScrollIds() { - return Collections.unmodifiableList(pastIds); + + synchronized (lock) { + return Collections.unmodifiableList(new ArrayList<>(pastIds)); + } } void updateScrollId(String scrollId) { @@ -669,6 +688,5 @@ public class DefaultReactiveElasticsearchClient implements ReactiveElasticsearch } } } - } }