diff --git a/src/main/java/org/springframework/data/elasticsearch/client/reactive/DefaultReactiveElasticsearchClient.java b/src/main/java/org/springframework/data/elasticsearch/client/reactive/DefaultReactiveElasticsearchClient.java index c92383346..ab2c55a52 100644 --- a/src/main/java/org/springframework/data/elasticsearch/client/reactive/DefaultReactiveElasticsearchClient.java +++ b/src/main/java/org/springframework/data/elasticsearch/client/reactive/DefaultReactiveElasticsearchClient.java @@ -145,7 +145,7 @@ import org.springframework.web.reactive.function.client.WebClient.RequestBodySpe */ public class DefaultReactiveElasticsearchClient implements ReactiveElasticsearchClient, Indices { - private final HostProvider hostProvider; + private final HostProvider hostProvider; private final RequestCreator requestCreator; private Supplier headersSupplier = () -> HttpHeaders.EMPTY; @@ -155,7 +155,7 @@ public class DefaultReactiveElasticsearchClient implements ReactiveElasticsearch * * @param hostProvider must not be {@literal null}. */ - public DefaultReactiveElasticsearchClient(HostProvider hostProvider) { + public DefaultReactiveElasticsearchClient(HostProvider hostProvider) { this(hostProvider, new DefaultRequestCreator()); } @@ -166,7 +166,7 @@ public class DefaultReactiveElasticsearchClient implements ReactiveElasticsearch * @param hostProvider must not be {@literal null}. * @param requestCreator must not be {@literal null}. */ - public DefaultReactiveElasticsearchClient(HostProvider hostProvider, RequestCreator requestCreator) { + public DefaultReactiveElasticsearchClient(HostProvider hostProvider, RequestCreator requestCreator) { Assert.notNull(hostProvider, "HostProvider must not be null"); Assert.notNull(requestCreator, "RequestCreator must not be null"); @@ -224,7 +224,7 @@ public class DefaultReactiveElasticsearchClient implements ReactiveElasticsearch WebClientProvider provider = getWebClientProvider(clientConfiguration); - HostProvider hostProvider = HostProvider.provider(provider, clientConfiguration.getHeadersSupplier(), + HostProvider hostProvider = HostProvider.provider(provider, clientConfiguration.getHeadersSupplier(), clientConfiguration.getEndpoints().toArray(new InetSocketAddress[0])); DefaultReactiveElasticsearchClient client = new DefaultReactiveElasticsearchClient(hostProvider, requestCreator); diff --git a/src/main/java/org/springframework/data/elasticsearch/client/reactive/HostProvider.java b/src/main/java/org/springframework/data/elasticsearch/client/reactive/HostProvider.java index 2474315dc..bff8f1e08 100644 --- a/src/main/java/org/springframework/data/elasticsearch/client/reactive/HostProvider.java +++ b/src/main/java/org/springframework/data/elasticsearch/client/reactive/HostProvider.java @@ -34,19 +34,20 @@ import org.springframework.web.reactive.function.client.WebClient; * * @author Christoph Strobl * @author Mark Paluch + * @author Peter-Josef Meisch * @since 3.2 */ -public interface HostProvider { +public interface HostProvider> { /** * Create a new {@link HostProvider} best suited for the given {@link WebClientProvider} and number of hosts. * * @param clientProvider must not be {@literal null} . - * @param headersSupplier to supply custom headers, must not be {@literal null} + * @param headersSupplier to supply custom headers, must not be {@literal null} * @param endpoints must not be {@literal null} nor empty. * @return new instance of {@link HostProvider}. */ - static HostProvider provider(WebClientProvider clientProvider, Supplier headersSupplier, + static HostProvider provider(WebClientProvider clientProvider, Supplier headersSupplier, InetSocketAddress... endpoints) { Assert.notNull(clientProvider, "WebClientProvider must not be null"); @@ -55,7 +56,7 @@ public interface HostProvider { if (endpoints.length == 1) { return new SingleNodeHostProvider(clientProvider, headersSupplier, endpoints[0]); } else { - return new MultiNodeHostProvider(clientProvider,headersSupplier, endpoints); + return new MultiNodeHostProvider(clientProvider, headersSupplier, endpoints); } } diff --git a/src/main/java/org/springframework/data/elasticsearch/client/reactive/MultiNodeHostProvider.java b/src/main/java/org/springframework/data/elasticsearch/client/reactive/MultiNodeHostProvider.java index 86e65ffd9..df98f04a1 100644 --- a/src/main/java/org/springframework/data/elasticsearch/client/reactive/MultiNodeHostProvider.java +++ b/src/main/java/org/springframework/data/elasticsearch/client/reactive/MultiNodeHostProvider.java @@ -42,15 +42,17 @@ import org.springframework.web.reactive.function.client.WebClient; * * @author Christoph Strobl * @author Mark Paluch + * @author Peter-Josef Meisch * @since 3.2 */ -class MultiNodeHostProvider implements HostProvider { +class MultiNodeHostProvider implements HostProvider { private final WebClientProvider clientProvider; private final Supplier headersSupplier; private final Map hosts; - MultiNodeHostProvider(WebClientProvider clientProvider, Supplier headersSupplier, InetSocketAddress... endpoints) { + MultiNodeHostProvider(WebClientProvider clientProvider, Supplier headersSupplier, + InetSocketAddress... endpoints) { this.clientProvider = clientProvider; this.headersSupplier = headersSupplier; @@ -136,16 +138,19 @@ class MultiNodeHostProvider implements HostProvider { .map(ElasticsearchHost::getEndpoint) // .flatMap(host -> { - Mono exchange = createWebClient(host) // + Mono clientResponseMono = createWebClient(host) // .head().uri("/") // .headers(httpHeaders -> httpHeaders.addAll(headersSupplier.get())) // - .exchange().doOnError(throwable -> { + .exchangeToMono(Mono::just) // + .doOnError(throwable -> { hosts.put(host, new ElasticsearchHost(host, State.OFFLINE)); clientProvider.getErrorListener().accept(throwable); }); - return Mono.just(host).zipWith(exchange - .flatMap(it -> it.releaseBody().thenReturn(it.statusCode().isError() ? State.OFFLINE : State.ONLINE))); + return Mono.just(host) // + .zipWith( // + clientResponseMono.flatMap(it -> it.releaseBody() // + .thenReturn(it.statusCode().isError() ? State.OFFLINE : State.ONLINE))); }) // .onErrorContinue((throwable, o) -> clientProvider.getErrorListener().accept(throwable)); } diff --git a/src/main/java/org/springframework/data/elasticsearch/client/reactive/SingleNodeHostProvider.java b/src/main/java/org/springframework/data/elasticsearch/client/reactive/SingleNodeHostProvider.java index 24360cac7..583bea527 100644 --- a/src/main/java/org/springframework/data/elasticsearch/client/reactive/SingleNodeHostProvider.java +++ b/src/main/java/org/springframework/data/elasticsearch/client/reactive/SingleNodeHostProvider.java @@ -32,9 +32,10 @@ import org.springframework.web.reactive.function.client.WebClient; * * @author Christoph Strobl * @author Mark Paluch + * @author Peter-Josef Meisch * @since 3.2 */ -class SingleNodeHostProvider implements HostProvider { +class SingleNodeHostProvider implements HostProvider { private final WebClientProvider clientProvider; private final Supplier headersSupplier; @@ -60,20 +61,18 @@ class SingleNodeHostProvider implements HostProvider { return createWebClient(endpoint) // .head().uri("/") // .headers(httpHeaders -> httpHeaders.addAll(headersSupplier.get())) // - .exchange() // - .flatMap(it -> { + .exchangeToMono(it -> { if (it.statusCode().isError()) { state = ElasticsearchHost.offline(endpoint); } else { state = ElasticsearchHost.online(endpoint); } - return it.releaseBody().thenReturn(state); + return Mono.just(state); }).onErrorResume(throwable -> { state = ElasticsearchHost.offline(endpoint); clientProvider.getErrorListener().accept(throwable); return Mono.just(state); - }) // - .map(it -> new ClusterInformation(Collections.singleton(it))); + }).map(elasticsearchHost -> new ClusterInformation(Collections.singleton(elasticsearchHost))); } /* diff --git a/src/test/java/org/springframework/data/elasticsearch/client/reactive/MultiNodeHostProviderUnitTests.java b/src/test/java/org/springframework/data/elasticsearch/client/reactive/MultiNodeHostProviderUnitTests.java index 1b2f89ea8..a6377e7c9 100644 --- a/src/test/java/org/springframework/data/elasticsearch/client/reactive/MultiNodeHostProviderUnitTests.java +++ b/src/test/java/org/springframework/data/elasticsearch/client/reactive/MultiNodeHostProviderUnitTests.java @@ -18,6 +18,7 @@ package org.springframework.data.elasticsearch.client.reactive; import static org.assertj.core.api.Assertions.*; import static org.mockito.Mockito.*; +import org.mockito.invocation.InvocationOnMock; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; @@ -30,8 +31,11 @@ import org.springframework.data.elasticsearch.client.reactive.ReactiveMockClient import org.springframework.data.elasticsearch.client.reactive.ReactiveMockClientTestsUtils.MockWebClientProvider.Receive; import org.springframework.web.reactive.function.client.ClientResponse; +import java.util.function.Function; + /** * @author Christoph Strobl + * @author Peter-Josef Meisch */ public class MultiNodeHostProviderUnitTests { @@ -39,82 +43,85 @@ public class MultiNodeHostProviderUnitTests { static final String HOST_2 = ":9201"; static final String HOST_3 = ":9202"; - MockDelegatingElasticsearchHostProvider mock; - MultiNodeHostProvider provider; + MockDelegatingElasticsearchHostProvider multiNodeDelegatingHostProvider; + MultiNodeHostProvider delegateHostProvider; @BeforeEach public void setUp() { - mock = ReactiveMockClientTestsUtils.multi(HOST_1, HOST_2, HOST_3); - provider = mock.getDelegate(); + multiNodeDelegatingHostProvider = ReactiveMockClientTestsUtils.multi(HOST_1, HOST_2, HOST_3); + delegateHostProvider = multiNodeDelegatingHostProvider.getDelegate(); } @Test // DATAES-488 public void refreshHostStateShouldUpdateNodeStateCorrectly() { - mock.when(HOST_1).receive(Receive::error); - mock.when(HOST_2).receive(Receive::ok); - mock.when(HOST_3).receive(Receive::ok); + multiNodeDelegatingHostProvider.when(HOST_1).receive(Receive::error); + multiNodeDelegatingHostProvider.when(HOST_2).receive(Receive::ok); + multiNodeDelegatingHostProvider.when(HOST_3).receive(Receive::ok); - provider.clusterInfo().as(StepVerifier::create).expectNextCount(1).verifyComplete(); + delegateHostProvider.clusterInfo().as(StepVerifier::create).expectNextCount(1).verifyComplete(); - assertThat(provider.getCachedHostState()).extracting(ElasticsearchHost::getState).containsExactly(State.OFFLINE, - State.ONLINE, State.ONLINE); + assertThat(delegateHostProvider.getCachedHostState()).extracting(ElasticsearchHost::getState) + .containsExactly(State.OFFLINE, State.ONLINE, State.ONLINE); } @Test // DATAES-488 public void getActiveReturnsFirstActiveHost() { - mock.when(HOST_1).receive(Receive::error); - mock.when(HOST_2).receive(Receive::ok); - mock.when(HOST_3).receive(Receive::error); + multiNodeDelegatingHostProvider.when(HOST_1).receive(Receive::error); + multiNodeDelegatingHostProvider.when(HOST_2).receive(Receive::ok); + multiNodeDelegatingHostProvider.when(HOST_3).receive(Receive::error); - provider.getActive().as(StepVerifier::create).expectNext(mock.client(HOST_2)).verifyComplete(); + delegateHostProvider.getActive().as(StepVerifier::create).expectNext(multiNodeDelegatingHostProvider.client(HOST_2)) + .verifyComplete(); } @Test // DATAES-488 public void getActiveErrorsWhenNoActiveHostFound() { - mock.when(HOST_1).receive(Receive::error); - mock.when(HOST_2).receive(Receive::error); - mock.when(HOST_3).receive(Receive::error); + multiNodeDelegatingHostProvider.when(HOST_1).receive(Receive::error); + multiNodeDelegatingHostProvider.when(HOST_2).receive(Receive::error); + multiNodeDelegatingHostProvider.when(HOST_3).receive(Receive::error); - provider.getActive().as(StepVerifier::create).expectError(IllegalStateException.class); + delegateHostProvider.getActive().as(StepVerifier::create).expectError(IllegalStateException.class); } @Test // DATAES-488 public void lazyModeDoesNotResolveHostsTwice() { - mock.when(HOST_1).receive(Receive::error); - mock.when(HOST_2).receive(Receive::ok); - mock.when(HOST_3).receive(Receive::error); + multiNodeDelegatingHostProvider.when(HOST_1).receive(Receive::error); + multiNodeDelegatingHostProvider.when(HOST_2).receive(Receive::ok); + multiNodeDelegatingHostProvider.when(HOST_3).receive(Receive::error); - provider.clusterInfo().as(StepVerifier::create).expectNextCount(1).verifyComplete(); + delegateHostProvider.clusterInfo().as(StepVerifier::create).expectNextCount(1).verifyComplete(); - provider.getActive(Verification.LAZY).as(StepVerifier::create).expectNext(mock.client(HOST_2)).verifyComplete(); + delegateHostProvider.getActive(Verification.LAZY).as(StepVerifier::create) + .expectNext(multiNodeDelegatingHostProvider.client(HOST_2)).verifyComplete(); - verify(mock.client(":9201")).head(); + verify(multiNodeDelegatingHostProvider.client(":9201")).head(); } @Test // DATAES-488 public void alwaysModeDoesNotResolveHostsTwice() { - mock.when(HOST_1).receive(Receive::error); - mock.when(HOST_2).receive(Receive::ok); - mock.when(HOST_3).receive(Receive::error); + multiNodeDelegatingHostProvider.when(HOST_1).receive(Receive::error); + multiNodeDelegatingHostProvider.when(HOST_2).receive(Receive::ok); + multiNodeDelegatingHostProvider.when(HOST_3).receive(Receive::error); - provider.clusterInfo().as(StepVerifier::create).expectNextCount(1).verifyComplete(); + delegateHostProvider.clusterInfo().as(StepVerifier::create).expectNextCount(1).verifyComplete(); - provider.getActive(Verification.ACTIVE).as(StepVerifier::create).expectNext(mock.client(HOST_2)).verifyComplete(); + delegateHostProvider.getActive(Verification.ACTIVE).as(StepVerifier::create) + .expectNext(multiNodeDelegatingHostProvider.client(HOST_2)).verifyComplete(); - verify(mock.client(HOST_2), times(2)).head(); + verify(multiNodeDelegatingHostProvider.client(HOST_2), times(2)).head(); } @Test // DATAES-488 public void triesDeadHostsIfNoActiveFound() { - mock.when(HOST_1).receive(Receive::error); - mock.when(HOST_2).get(requestHeadersUriSpec -> { + multiNodeDelegatingHostProvider.when(HOST_1).receive(Receive::error); + multiNodeDelegatingHostProvider.when(HOST_2).get(requestHeadersUriSpec -> { ClientResponse response1 = mock(ClientResponse.class); when(response1.releaseBody()).thenReturn(Mono.empty()); @@ -124,17 +131,29 @@ public class MultiNodeHostProviderUnitTests { when(response2.releaseBody()).thenReturn(Mono.empty()); Receive.ok(response2); - when(requestHeadersUriSpec.exchange()).thenReturn(Mono.just(response1), Mono.just(response2)); + when(requestHeadersUriSpec.exchangeToMono(any()))// + .thenAnswer(invocation -> getAnswer(invocation, response1)) // + .thenAnswer(invocation -> getAnswer(invocation, response2)); }); - mock.when(HOST_3).receive(Receive::error); + multiNodeDelegatingHostProvider.when(HOST_3).receive(Receive::error); - provider.clusterInfo().as(StepVerifier::create).expectNextCount(1).verifyComplete(); - assertThat(provider.getCachedHostState()).extracting(ElasticsearchHost::getState).containsExactly(State.OFFLINE, - State.OFFLINE, State.OFFLINE); + delegateHostProvider.clusterInfo().as(StepVerifier::create).expectNextCount(1).verifyComplete(); + assertThat(delegateHostProvider.getCachedHostState()).extracting(ElasticsearchHost::getState) + .containsExactly(State.OFFLINE, State.OFFLINE, State.OFFLINE); - provider.getActive().as(StepVerifier::create).expectNext(mock.client(HOST_2)).verifyComplete(); + delegateHostProvider.getActive().as(StepVerifier::create).expectNext(multiNodeDelegatingHostProvider.client(HOST_2)) + .verifyComplete(); - verify(mock.client(HOST_2), times(2)).head(); + verify(multiNodeDelegatingHostProvider.client(HOST_2), times(2)).head(); + } + + private Mono getAnswer(InvocationOnMock invocation, ClientResponse response) { + final Function> responseHandler = invocation.getArgument(0); + + if (responseHandler != null) { + return responseHandler.apply(response); + } + return Mono.empty(); } } diff --git a/src/test/java/org/springframework/data/elasticsearch/client/reactive/ReactiveElasticsearchClientUnitTests.java b/src/test/java/org/springframework/data/elasticsearch/client/reactive/ReactiveElasticsearchClientUnitTests.java index 4b0b4e79d..f62251c26 100644 --- a/src/test/java/org/springframework/data/elasticsearch/client/reactive/ReactiveElasticsearchClientUnitTests.java +++ b/src/test/java/org/springframework/data/elasticsearch/client/reactive/ReactiveElasticsearchClientUnitTests.java @@ -62,7 +62,7 @@ public class ReactiveElasticsearchClientUnitTests { static final String HOST = ":9200"; - MockDelegatingElasticsearchHostProvider hostProvider; + MockDelegatingElasticsearchHostProvider> hostProvider; ReactiveElasticsearchClient client; @BeforeEach diff --git a/src/test/java/org/springframework/data/elasticsearch/client/reactive/ReactiveMockClientTestsUtils.java b/src/test/java/org/springframework/data/elasticsearch/client/reactive/ReactiveMockClientTestsUtils.java index 7fc1e5ca4..fed886655 100644 --- a/src/test/java/org/springframework/data/elasticsearch/client/reactive/ReactiveMockClientTestsUtils.java +++ b/src/test/java/org/springframework/data/elasticsearch/client/reactive/ReactiveMockClientTestsUtils.java @@ -75,24 +75,23 @@ public class ReactiveMockClientTestsUtils { return provider(hosts); } - public static MockDelegatingElasticsearchHostProvider provider(String... hosts) { + public static > MockDelegatingElasticsearchHostProvider provider(String... hosts) { ErrorCollector errorCollector = new ErrorCollector(); MockWebClientProvider clientProvider = new MockWebClientProvider(errorCollector); - HostProvider delegate = null; + T delegate; if (hosts.length == 1) { - - delegate = new SingleNodeHostProvider(clientProvider, HttpHeaders::new, getInetSocketAddress(hosts[0])) {}; + // noinspection unchecked + delegate = (T) new SingleNodeHostProvider(clientProvider, HttpHeaders::new, getInetSocketAddress(hosts[0])) {}; } else { - - delegate = new MultiNodeHostProvider(clientProvider,HttpHeaders::new, Arrays.stream(hosts) + // noinspection unchecked + delegate = (T) new MultiNodeHostProvider(clientProvider, HttpHeaders::new, Arrays.stream(hosts) .map(ReactiveMockClientTestsUtils::getInetSocketAddress).toArray(InetSocketAddress[]::new)) {}; } - return new MockDelegatingElasticsearchHostProvider(HttpHeaders.EMPTY, clientProvider, errorCollector, delegate, + return new MockDelegatingElasticsearchHostProvider<>(HttpHeaders.EMPTY, clientProvider, errorCollector, delegate, null); - } private static InetSocketAddress getInetSocketAddress(String hostAndPort) { @@ -113,16 +112,18 @@ public class ReactiveMockClientTestsUtils { } } - public static class MockDelegatingElasticsearchHostProvider implements HostProvider { + public static class MockDelegatingElasticsearchHostProvider> implements HostProvider { + private final HttpHeaders httpHeaders; private final T delegate; private final MockWebClientProvider clientProvider; private final ErrorCollector errorCollector; - private @Nullable String activeDefaultHost; + private @Nullable final String activeDefaultHost; public MockDelegatingElasticsearchHostProvider(HttpHeaders httpHeaders, MockWebClientProvider clientProvider, - ErrorCollector errorCollector, T delegate, String activeDefaultHost) { + ErrorCollector errorCollector, T delegate, @Nullable String activeDefaultHost) { + this.httpHeaders = httpHeaders; this.errorCollector = errorCollector; this.clientProvider = clientProvider; this.delegate = delegate; @@ -187,25 +188,23 @@ public class ReactiveMockClientTestsUtils { } public MockDelegatingElasticsearchHostProvider withActiveDefaultHost(String host) { - return new MockDelegatingElasticsearchHostProvider(HttpHeaders.EMPTY, clientProvider, errorCollector, delegate, - host); + return new MockDelegatingElasticsearchHostProvider<>(httpHeaders, clientProvider, errorCollector, delegate, host); } } public static class MockWebClientProvider implements WebClientProvider { - private final Object lock = new Object(); private final Consumer errorListener; - private Map clientMap; - private Map headersUriSpecMap; - private Map bodyUriSpecMap; - private Map responseMap; + private final Map clientMap; + private final Map headersUriSpecMap; + private final Map bodyUriSpecMap; + private final Map responseMap; public MockWebClientProvider(Consumer errorListener) { this.errorListener = errorListener; - this.clientMap = new LinkedHashMap<>(); + this.clientMap = new ConcurrentHashMap<>(); this.headersUriSpecMap = new LinkedHashMap<>(); this.bodyUriSpecMap = new LinkedHashMap<>(); this.responseMap = new LinkedHashMap<>(); @@ -218,40 +217,49 @@ public class ReactiveMockClientTestsUtils { @Override public WebClient get(InetSocketAddress endpoint) { - synchronized (lock) { + return clientMap.computeIfAbsent(endpoint, key -> { - return clientMap.computeIfAbsent(endpoint, key -> { + WebClient webClient = mock(WebClient.class); - WebClient webClient = mock(WebClient.class); + RequestHeadersUriSpec headersUriSpec = mock(RequestHeadersUriSpec.class); + Mockito.when(headersUriSpec.uri(any(String.class))).thenReturn(headersUriSpec); + Mockito.when(headersUriSpec.uri(any(), any(Map.class))).thenReturn(headersUriSpec); + Mockito.when(headersUriSpec.headers(any(Consumer.class))).thenReturn(headersUriSpec); + Mockito.when(headersUriSpec.attribute(anyString(), anyString())).thenReturn(headersUriSpec); + Mockito.when(headersUriSpec.uri(any(Function.class))).thenReturn(headersUriSpec); + headersUriSpecMap.putIfAbsent(key, headersUriSpec); - RequestHeadersUriSpec headersUriSpec = mock(RequestHeadersUriSpec.class); - Mockito.when(webClient.get()).thenReturn(headersUriSpec); - Mockito.when(webClient.head()).thenReturn(headersUriSpec); + ClientResponse response = mock(ClientResponse.class); + Mockito.when(response.statusCode()).thenReturn(HttpStatus.ACCEPTED); + Mockito.when(response.releaseBody()).thenReturn(Mono.empty()); + Mockito.when(headersUriSpec.exchangeToMono(any())).thenAnswer(invocation -> { + final Function> responseHandler = invocation.getArgument(0); - Mockito.when(headersUriSpec.uri(any(String.class))).thenReturn(headersUriSpec); - Mockito.when(headersUriSpec.uri(any(), any(Map.class))).thenReturn(headersUriSpec); - Mockito.when(headersUriSpec.headers(any(Consumer.class))).thenReturn(headersUriSpec); - Mockito.when(headersUriSpec.attribute(anyString(), anyString())).thenReturn(headersUriSpec); - Mockito.when(headersUriSpec.uri(any(Function.class))).thenReturn(headersUriSpec); - - RequestBodyUriSpec bodySpy = spy(WebClient.create().method(HttpMethod.POST)); - - Mockito.when(webClient.method(any())).thenReturn(bodySpy); - Mockito.when(bodySpy.body(any())).thenReturn(headersUriSpec); - - ClientResponse response = mock(ClientResponse.class); - Mockito.when(headersUriSpec.exchange()).thenReturn(Mono.just(response)); - Mockito.when(bodySpy.exchange()).thenReturn(Mono.just(response)); - Mockito.when(response.statusCode()).thenReturn(HttpStatus.ACCEPTED); - Mockito.when(response.releaseBody()).thenReturn(Mono.empty()); - - headersUriSpecMap.putIfAbsent(key, headersUriSpec); - bodyUriSpecMap.putIfAbsent(key, bodySpy); - responseMap.putIfAbsent(key, response); - - return webClient; + if (responseHandler != null) { + return responseHandler.apply(response); + } + return Mono.empty(); }); - } + responseMap.putIfAbsent(key, response); + + RequestBodyUriSpec bodySpy = spy(WebClient.create().method(HttpMethod.POST)); + Mockito.when(bodySpy.body(any())).thenReturn(headersUriSpec); + Mockito.when(bodySpy.exchangeToMono(any())).thenAnswer(invocation -> { + final Function> responseHandler = invocation.getArgument(0); + + if (responseHandler != null) { + return responseHandler.apply(response); + } + return Mono.empty(); + }); + bodyUriSpecMap.putIfAbsent(key, bodySpy); + + Mockito.when(webClient.get()).thenReturn(headersUriSpec); + Mockito.when(webClient.head()).thenReturn(headersUriSpec); + Mockito.when(webClient.method(any())).thenReturn(bodySpy); + + return webClient; + }); } @Override @@ -299,18 +307,20 @@ public class ReactiveMockClientTestsUtils { WebClient client(); } + @SuppressWarnings("UnusedReturnValue") public interface Send extends Receive, Client { - Receive get(Consumer headerSpec); + Receive get(Consumer> headerSpec); Receive exchange(Consumer bodySpec); default URI captureUri() { - Set capturingSet = new LinkedHashSet(); + Set capturingSet = new LinkedHashSet<>(); exchange(requestBodyUriSpec -> { + // noinspection unchecked ArgumentCaptor> fkt = ArgumentCaptor.forClass(Function.class); verify(requestBodyUriSpec).uri(fkt.capture()); @@ -354,9 +364,8 @@ public class ReactiveMockClientTestsUtils { default Receive receiveGetByIdNotFound() { return receiveJsonFromFile("get-by-id-no-hit") // - .receive(response -> { - Mockito.when(response.statusCode()).thenReturn(HttpStatus.ACCEPTED, HttpStatus.NOT_FOUND); - }); + .receive( + response -> Mockito.when(response.statusCode()).thenReturn(HttpStatus.ACCEPTED, HttpStatus.NOT_FOUND)); } default Receive receiveGetById() { @@ -380,9 +389,8 @@ public class ReactiveMockClientTestsUtils { default Receive updateFail() { return receiveJsonFromFile("update-error-not-found") // - .receive(response -> { - Mockito.when(response.statusCode()).thenReturn(HttpStatus.ACCEPTED, HttpStatus.NOT_FOUND); - }); + .receive( + response -> Mockito.when(response.statusCode()).thenReturn(HttpStatus.ACCEPTED, HttpStatus.NOT_FOUND)); } default Receive receiveBulkOk() { @@ -445,14 +453,14 @@ public class ReactiveMockClientTestsUtils { } } - class CallbackImpl implements Send, Receive { + static class CallbackImpl implements Send, Receive { WebClient client; - RequestHeadersUriSpec headersUriSpec; + RequestHeadersUriSpec headersUriSpec; RequestBodyUriSpec bodyUriSpec; ClientResponse responseDelegate; - public CallbackImpl(WebClient client, RequestHeadersUriSpec headersUriSpec, RequestBodyUriSpec bodyUriSpec, + public CallbackImpl(WebClient client, RequestHeadersUriSpec headersUriSpec, RequestBodyUriSpec bodyUriSpec, ClientResponse responseDelegate) { this.client = client; @@ -462,7 +470,7 @@ public class ReactiveMockClientTestsUtils { } @Override - public Receive get(Consumer uriSpec) { + public Receive get(Consumer> uriSpec) { uriSpec.accept(headersUriSpec); return this; diff --git a/src/test/java/org/springframework/data/elasticsearch/client/reactive/SingleNodeHostProviderUnitTests.java b/src/test/java/org/springframework/data/elasticsearch/client/reactive/SingleNodeHostProviderUnitTests.java index a27e22389..b41fafbda 100644 --- a/src/test/java/org/springframework/data/elasticsearch/client/reactive/SingleNodeHostProviderUnitTests.java +++ b/src/test/java/org/springframework/data/elasticsearch/client/reactive/SingleNodeHostProviderUnitTests.java @@ -30,6 +30,7 @@ import org.springframework.data.elasticsearch.client.reactive.ReactiveMockClient /** * @author Christoph Strobl + * @author Peter-Josef Meisch */ public class SingleNodeHostProviderUnitTests {