diff --git a/src/main/java/org/springframework/data/elasticsearch/client/reactive/DefaultReactiveElasticsearchClient.java b/src/main/java/org/springframework/data/elasticsearch/client/reactive/DefaultReactiveElasticsearchClient.java index 7826837bb..15e32ca38 100644 --- a/src/main/java/org/springframework/data/elasticsearch/client/reactive/DefaultReactiveElasticsearchClient.java +++ b/src/main/java/org/springframework/data/elasticsearch/client/reactive/DefaultReactiveElasticsearchClient.java @@ -32,6 +32,7 @@ import java.net.InetSocketAddress; import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.Collection; +import java.util.Map.Entry; import java.util.Optional; import java.util.concurrent.TimeUnit; import java.util.function.Function; @@ -82,6 +83,7 @@ import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.client.reactive.ReactorClientHttpConnector; import org.springframework.util.Assert; +import org.springframework.util.ObjectUtils; import org.springframework.util.ReflectionUtils; import org.springframework.web.client.HttpServerErrorException; import org.springframework.web.reactive.function.BodyExtractors; @@ -360,9 +362,31 @@ public class DefaultReactiveElasticsearchClient implements ReactiveElasticsearch private Mono sendRequest(WebClient webClient, String logId, Request request, HttpHeaders headers) { RequestBodySpec requestBodySpec = webClient.method(HttpMethod.valueOf(request.getMethod().toUpperCase())) // - .uri(request.getEndpoint(), request.getParameters()) // + .uri(builder -> { + + builder = builder.path(request.getEndpoint()); + + if (!ObjectUtils.isEmpty(request.getParameters())) { + for (Entry entry : request.getParameters().entrySet()) { + builder = builder.queryParam(entry.getKey(), entry.getValue()); + } + } + return builder.build(); + }) // .attribute(ClientRequest.LOG_ID_ATTRIBUTE, logId) // - .headers(theHeaders -> theHeaders.addAll(headers)); + .headers(theHeaders -> { + + // add all the headers explicitly set + theHeaders.addAll(headers); + + // and now those that might be set on the request. + if (request.getOptions() != null) { + + if (!ObjectUtils.isEmpty(request.getOptions().getHeaders())) { + request.getOptions().getHeaders().forEach(it -> theHeaders.add(it.getName(), it.getValue())); + } + } + }); if (request.getEntity() != null) { diff --git a/src/test/java/org/springframework/data/elasticsearch/client/reactive/ReactiveElasticsearchClientUnitTests.java b/src/test/java/org/springframework/data/elasticsearch/client/reactive/ReactiveElasticsearchClientUnitTests.java index f7637d0d8..02a431f43 100644 --- a/src/test/java/org/springframework/data/elasticsearch/client/reactive/ReactiveElasticsearchClientUnitTests.java +++ b/src/test/java/org/springframework/data/elasticsearch/client/reactive/ReactiveElasticsearchClientUnitTests.java @@ -16,14 +16,14 @@ package org.springframework.data.elasticsearch.client.reactive; import static org.assertj.core.api.Assertions.*; -import static org.mockito.ArgumentMatchers.*; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.*; import static org.springframework.data.elasticsearch.client.reactive.ReactiveMockClientTestsUtils.MockWebClientProvider.Receive.*; import reactor.test.StepVerifier; +import java.net.URI; import java.util.Collections; -import java.util.Map; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.DocWriteResponse.Result; @@ -33,7 +33,9 @@ import org.elasticsearch.action.get.MultiGetRequest; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.update.UpdateRequest; +import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.index.VersionType; import org.junit.Before; import org.junit.Test; import org.reactivestreams.Publisher; @@ -61,6 +63,29 @@ public class ReactiveElasticsearchClientUnitTests { client = new DefaultReactiveElasticsearchClient(hostProvider); } + @Test // DATAES-512 + public void sendRequestShouldCarryOnRequestParameters() { + + hostProvider.when(HOST).receiveDeleteOk(); + + DeleteRequest request = new DeleteRequest("index", "type", "id"); + request.version(1000); + request.versionType(VersionType.EXTERNAL); + request.timeout(TimeValue.timeValueMinutes(10)); + + client.delete(request) // + .then() // + .as(StepVerifier::create) // + .verifyComplete(); + + URI uri = hostProvider.when(HOST).captureUri(); + + assertThat(uri.getQuery()) // + .contains("version=1000") // + .contains("version_type=external") // + .contains("timeout=10m"); + } + // --> PING @Test @@ -74,9 +99,8 @@ public class ReactiveElasticsearchClientUnitTests { .as(StepVerifier::create) // .verifyComplete(); - hostProvider.when(HOST).exchange(requestBodyUriSpec -> { - verify(requestBodyUriSpec).uri(eq("/"), any(Map.class)); - }); + URI uri = hostProvider.when(HOST).captureUri(); + assertThat(uri.getRawPath()).isEqualTo("/"); } @Test // DATAES-488 @@ -116,9 +140,8 @@ public class ReactiveElasticsearchClientUnitTests { .as(StepVerifier::create) // .verifyComplete(); - hostProvider.when(HOST).exchange(requestBodyUriSpec -> { - verify(requestBodyUriSpec).uri(eq("/"), any(Map.class)); - }); + URI uri = hostProvider.when(HOST).captureUri(); + assertThat(uri.getRawPath()).isEqualTo("/"); } @Test // DATAES-488 @@ -151,9 +174,8 @@ public class ReactiveElasticsearchClientUnitTests { .verifyComplete(); verify(hostProvider.client(HOST)).method(HttpMethod.GET); - hostProvider.when(HOST).exchange(requestBodyUriSpec -> { - verify(requestBodyUriSpec).uri(eq("/twitter/_all/1"), any(Map.class)); - }); + URI uri = hostProvider.when(HOST).captureUri(); + assertThat(uri.getRawPath()).isEqualTo("/twitter/_all/1"); } @Test // DATAES-488 @@ -204,10 +226,11 @@ public class ReactiveElasticsearchClientUnitTests { verify(hostProvider.client(HOST)).method(HttpMethod.POST); hostProvider.when(HOST).exchange(requestBodyUriSpec -> { - - verify(requestBodyUriSpec).uri(eq("/_mget"), any(Map.class)); verify(requestBodyUriSpec).body(any(Publisher.class), any(Class.class)); }); + + URI uri = hostProvider.when(HOST).captureUri(); + assertThat(uri.getRawPath()).isEqualTo("/_mget"); } @Test // DATAES-488 @@ -287,9 +310,8 @@ public class ReactiveElasticsearchClientUnitTests { verify(hostProvider.client(HOST)).method(HttpMethod.HEAD); - hostProvider.when(HOST).exchange(requestBodyUriSpec -> { - verify(requestBodyUriSpec).uri(eq("/twitter/_all/1"), any(Map.class)); - }); + URI uri = hostProvider.when(HOST).captureUri(); + assertThat(uri.getRawPath()).isEqualTo("/twitter/_all/1"); } @Test // DATAES-488 @@ -329,10 +351,11 @@ public class ReactiveElasticsearchClientUnitTests { verify(hostProvider.client(HOST)).method(HttpMethod.PUT); hostProvider.when(HOST).exchange(requestBodyUriSpec -> { - - verify(requestBodyUriSpec).uri(eq("/twitter/10/_create"), any(Map.class)); verify(requestBodyUriSpec).contentType(MediaType.APPLICATION_JSON); }); + + URI uri = hostProvider.when(HOST).captureUri(); + assertThat(uri.getRawPath()).isEqualTo("/twitter/10/_create"); } @Test // DATAES-488 @@ -347,10 +370,11 @@ public class ReactiveElasticsearchClientUnitTests { verify(hostProvider.client(HOST)).method(HttpMethod.PUT); hostProvider.when(HOST).exchange(requestBodyUriSpec -> { - - verify(requestBodyUriSpec).uri(eq("/twitter/10"), any(Map.class)); verify(requestBodyUriSpec).contentType(MediaType.APPLICATION_JSON); }); + + URI uri = hostProvider.when(HOST).captureUri(); + assertThat(uri.getRawPath()).isEqualTo("/twitter/10"); } @Test // DATAES-488 @@ -401,10 +425,11 @@ public class ReactiveElasticsearchClientUnitTests { verify(hostProvider.client(HOST)).method(HttpMethod.POST); hostProvider.when(HOST).exchange(requestBodyUriSpec -> { - - verify(requestBodyUriSpec).uri(eq("/twitter/doc/1/_update"), any(Map.class)); verify(requestBodyUriSpec).contentType(MediaType.APPLICATION_JSON); }); + + URI uri = hostProvider.when(HOST).captureUri(); + assertThat(uri.getRawPath()).isEqualTo("/twitter/doc/1/_update"); } @Test // DATAES-488 @@ -450,9 +475,8 @@ public class ReactiveElasticsearchClientUnitTests { .verifyComplete(); verify(hostProvider.client(HOST)).method(HttpMethod.DELETE); - hostProvider.when(HOST).exchange(requestBodyUriSpec -> { - verify(requestBodyUriSpec).uri(eq("/twitter/doc/1"), any(Map.class)); - }); + URI uri = hostProvider.when(HOST).captureUri(); + assertThat(uri.getRawPath()).isEqualTo("/twitter/doc/1"); } @Test // DATAES-488 @@ -484,9 +508,8 @@ public class ReactiveElasticsearchClientUnitTests { client.search(new SearchRequest("twitter")).as(StepVerifier::create).verifyComplete(); verify(hostProvider.client(HOST)).method(HttpMethod.POST); - hostProvider.when(HOST).exchange(requestBodyUriSpec -> { - verify(requestBodyUriSpec).uri(eq("/twitter/_search"), any(Map.class)); - }); + URI uri = hostProvider.when(HOST).captureUri(); + assertThat(uri.getRawPath()).isEqualTo("/twitter/_search"); } @Test // DATAES-488 diff --git a/src/test/java/org/springframework/data/elasticsearch/client/reactive/ReactiveMockClientTestsUtils.java b/src/test/java/org/springframework/data/elasticsearch/client/reactive/ReactiveMockClientTestsUtils.java index c70b1d261..2cc6b9bd5 100644 --- a/src/test/java/org/springframework/data/elasticsearch/client/reactive/ReactiveMockClientTestsUtils.java +++ b/src/test/java/org/springframework/data/elasticsearch/client/reactive/ReactiveMockClientTestsUtils.java @@ -17,20 +17,28 @@ package org.springframework.data.elasticsearch.client.reactive; import static org.mockito.Mockito.*; +import org.mockito.ArgumentCaptor; +import org.springframework.http.HttpMethod; +import org.springframework.web.util.DefaultUriBuilderFactory; +import org.springframework.web.util.UriBuilder; import reactor.core.publisher.Mono; import java.io.IOException; import java.net.InetSocketAddress; +import java.net.URI; import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Collections; import java.util.LinkedHashMap; +import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CopyOnWriteArrayList; import java.util.function.Consumer; +import java.util.function.Function; import java.util.function.Supplier; import org.mockito.Mockito; @@ -215,21 +223,20 @@ public class ReactiveMockClientTestsUtils { Mockito.when(headersUriSpec.uri(any(), any(Map.class))).thenReturn(headersUriSpec); Mockito.when(headersUriSpec.headers(any(Consumer.class))).thenReturn(headersUriSpec); Mockito.when(headersUriSpec.attribute(anyString(), anyString())).thenReturn(headersUriSpec); + Mockito.when(headersUriSpec.uri(any(Function.class))).thenReturn(headersUriSpec); - RequestBodyUriSpec bodyUriSpec = mock(RequestBodyUriSpec.class); - Mockito.when(webClient.method(any())).thenReturn(bodyUriSpec); - Mockito.when(bodyUriSpec.body(any())).thenReturn(headersUriSpec); - Mockito.when(bodyUriSpec.uri(any(), any(Map.class))).thenReturn(bodyUriSpec); - Mockito.when(bodyUriSpec.attribute(anyString(), anyString())).thenReturn(bodyUriSpec); - Mockito.when(bodyUriSpec.headers(any(Consumer.class))).thenReturn(bodyUriSpec); + RequestBodyUriSpec bodySpy = spy(WebClient.create().method(HttpMethod.POST)); + + Mockito.when(webClient.method(any())).thenReturn(bodySpy); + Mockito.when(bodySpy.body(any())).thenReturn(headersUriSpec); ClientResponse response = mock(ClientResponse.class); Mockito.when(headersUriSpec.exchange()).thenReturn(Mono.just(response)); - Mockito.when(bodyUriSpec.exchange()).thenReturn(Mono.just(response)); + Mockito.when(bodySpy.exchange()).thenReturn(Mono.just(response)); Mockito.when(response.statusCode()).thenReturn(HttpStatus.ACCEPTED); headersUriSpecMap.putIfAbsent(key, headersUriSpec); - bodyUriSpecMap.putIfAbsent(key, bodyUriSpec); + bodyUriSpecMap.putIfAbsent(key, bodySpy); responseMap.putIfAbsent(key, response); return webClient; @@ -273,6 +280,21 @@ public class ReactiveMockClientTestsUtils { Receive exchange(Consumer bodySpec); + default URI captureUri() { + + Set capturingSet = new LinkedHashSet(); + + exchange(requestBodyUriSpec -> { + + ArgumentCaptor> fkt = ArgumentCaptor.forClass(Function.class); + verify(requestBodyUriSpec).uri(fkt.capture()); + + capturingSet.add(fkt.getValue().apply(new DefaultUriBuilderFactory().builder())); + }); + + return capturingSet.iterator().next(); + } + default Receive receiveJsonFromFile(String file) { return receive(Receive::json) //