diff --git a/src/main/java/org/springframework/data/elasticsearch/client/reactive/DefaultReactiveElasticsearchClient.java b/src/main/java/org/springframework/data/elasticsearch/client/reactive/DefaultReactiveElasticsearchClient.java index 1a787b7ad..86bdc04f3 100644 --- a/src/main/java/org/springframework/data/elasticsearch/client/reactive/DefaultReactiveElasticsearchClient.java +++ b/src/main/java/org/springframework/data/elasticsearch/client/reactive/DefaultReactiveElasticsearchClient.java @@ -22,9 +22,7 @@ import io.netty.handler.ssl.IdentityCipherSuiteFilter; import io.netty.handler.ssl.JdkSslContext; import io.netty.handler.timeout.ReadTimeoutHandler; import io.netty.handler.timeout.WriteTimeoutHandler; -import reactor.core.Exceptions; import reactor.core.publisher.Flux; -import reactor.core.publisher.FluxProcessor; import reactor.core.publisher.Mono; import reactor.core.publisher.Sinks; import reactor.netty.http.client.HttpClient; @@ -107,6 +105,9 @@ import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.aggregations.Aggregation; import org.elasticsearch.search.suggest.Suggest; import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + import org.springframework.data.elasticsearch.client.ClientConfiguration; import org.springframework.data.elasticsearch.client.ClientLogger; import org.springframework.data.elasticsearch.client.ElasticsearchHost; @@ -467,9 +468,7 @@ public class DefaultReactiveElasticsearchClient implements ReactiveElasticsearch } Sinks.Many requests = Sinks.many().unicast().onBackpressureBuffer(); - - FluxProcessor inbound = FluxProcessor - .fromSink(Sinks.many().unicast().onBackpressureBuffer()); + Sinks.Many inbound = Sinks.many().unicast().onBackpressureBuffer(); Flux exchange = requests.asFlux().flatMap(it -> { @@ -490,12 +489,12 @@ public class DefaultReactiveElasticsearchClient implements ReactiveElasticsearch scrollState -> { - Flux searchHits = inbound. handle((searchResponse, sink) -> { + Flux searchHits = inbound.asFlux(). handle((searchResponse, sink) -> { scrollState.updateScrollId(searchResponse.getScrollId()); if (isEmpty(searchResponse.getHits())) { - inbound.onComplete(); + inbound.tryEmitComplete(); requests.tryEmitComplete(); } else { @@ -504,15 +503,15 @@ public class DefaultReactiveElasticsearchClient implements ReactiveElasticsearch SearchScrollRequest searchScrollRequest = new SearchScrollRequest(scrollState.getScrollId()) .scroll(scrollTimeout); - tryEmitNext(requests, searchScrollRequest); + requests.emitNext(searchScrollRequest, Sinks.EmitFailureHandler.FAIL_FAST); } }).map(SearchResponse::getHits) // .flatMap(Flux::fromIterable); return searchHits.doOnSubscribe(ignore -> { - exchange.subscribe(inbound); - tryEmitNext(requests, searchRequest); + exchange.subscribe(new SinkSubscriber(inbound)); + requests.emitNext(searchRequest, Sinks.EmitFailureHandler.FAIL_FAST); }); }, state -> cleanupScroll(headers, state), // @@ -520,14 +519,6 @@ public class DefaultReactiveElasticsearchClient implements ReactiveElasticsearch state -> cleanupScroll(headers, state)); // } - private void tryEmitNext(Sinks.Many sink, ActionRequest request) { - - Sinks.Emission emission = sink.tryEmitNext(request); - - if (emission == Sinks.Emission.FAIL_OVERFLOW) { - sink.tryEmitError(Exceptions.failWithOverflow("Backpressure overflow during Sinks.Many#emitNext")); - } - } private static boolean isEmpty(@Nullable SearchHits hits) { return hits != null && hits.getHits() != null && hits.getHits().length == 0; @@ -964,5 +955,34 @@ public class DefaultReactiveElasticsearchClient implements ReactiveElasticsearch } } + private static class SinkSubscriber implements Subscriber { + + private final Sinks.Many inbound; + + public SinkSubscriber(Sinks.Many inbound) { + this.inbound = inbound; + } + + @Override + public void onSubscribe(Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(SearchResponse searchResponse) { + inbound.emitNext(searchResponse, Sinks.EmitFailureHandler.FAIL_FAST); + } + + @Override + public void onError(Throwable t) { + inbound.emitError(t, Sinks.EmitFailureHandler.FAIL_FAST); + } + + @Override + public void onComplete() { + inbound.emitComplete(Sinks.EmitFailureHandler.FAIL_FAST); + } + } + // endregion }