diff --git a/src/main/java/org/springframework/data/elasticsearch/client/reactive/DefaultReactiveElasticsearchClient.java b/src/main/java/org/springframework/data/elasticsearch/client/reactive/DefaultReactiveElasticsearchClient.java index a7e69670f..e91b4443e 100644 --- a/src/main/java/org/springframework/data/elasticsearch/client/reactive/DefaultReactiveElasticsearchClient.java +++ b/src/main/java/org/springframework/data/elasticsearch/client/reactive/DefaultReactiveElasticsearchClient.java @@ -281,15 +281,23 @@ public class DefaultReactiveElasticsearchClient implements ReactiveElasticsearch scheme = "https"; } - ReactorClientHttpConnector connector = new ReactorClientHttpConnector(httpClient); - WebClientProvider provider = WebClientProvider.create(scheme, connector); + WebClientProvider provider = WebClientProvider.create(scheme, new ReactorClientHttpConnector(httpClient)); if (clientConfiguration.getPathPrefix() != null) { provider = provider.withPathPrefix(clientConfiguration.getPathPrefix()); } - provider = provider.withDefaultHeaders(clientConfiguration.getDefaultHeaders()) // - .withWebClientConfigurer(clientConfiguration.getWebClientConfigurer()); + provider = provider // + .withDefaultHeaders(clientConfiguration.getDefaultHeaders()) // + .withWebClientConfigurer(clientConfiguration.getWebClientConfigurer()) // + .withRequestConfigurer(requestHeadersSpec -> requestHeadersSpec.headers(httpHeaders -> { + HttpHeaders suppliedHeaders = clientConfiguration.getHeadersSupplier().get(); + + if (suppliedHeaders != null && suppliedHeaders != HttpHeaders.EMPTY) { + httpHeaders.addAll(suppliedHeaders); + } + })); + return provider; } @@ -584,12 +592,6 @@ public class DefaultReactiveElasticsearchClient implements ReactiveElasticsearch request.getOptions().getHeaders().forEach(it -> theHeaders.add(it.getName(), it.getValue())); } } - - // plus the ones from the supplier - HttpHeaders suppliedHeaders = headersSupplier.get(); - if (suppliedHeaders != null && suppliedHeaders != HttpHeaders.EMPTY) { - theHeaders.addAll(suppliedHeaders); - } }); if (request.getEntity() != null) { @@ -599,8 +601,8 @@ public class DefaultReactiveElasticsearchClient implements ReactiveElasticsearch ClientLogger.logRequest(logId, request.getMethod().toUpperCase(), request.getEndpoint(), request.getParameters(), body::get); - requestBodySpec.contentType(MediaType.valueOf(request.getEntity().getContentType().getValue())); - requestBodySpec.body(Mono.fromSupplier(body), String.class); + requestBodySpec.contentType(MediaType.valueOf(request.getEntity().getContentType().getValue())) + .body(Mono.fromSupplier(body), String.class); } else { ClientLogger.logRequest(logId, request.getMethod().toUpperCase(), request.getEndpoint(), request.getParameters()); } diff --git a/src/main/java/org/springframework/data/elasticsearch/client/reactive/DefaultWebClientProvider.java b/src/main/java/org/springframework/data/elasticsearch/client/reactive/DefaultWebClientProvider.java index b94718646..8a486658a 100644 --- a/src/main/java/org/springframework/data/elasticsearch/client/reactive/DefaultWebClientProvider.java +++ b/src/main/java/org/springframework/data/elasticsearch/client/reactive/DefaultWebClientProvider.java @@ -48,6 +48,7 @@ class DefaultWebClientProvider implements WebClientProvider { private final HttpHeaders headers; private final @Nullable String pathPrefix; private final Function webClientConfigurer; + private final Consumer> requestConfigurer; /** * Create new {@link DefaultWebClientProvider} with empty {@link HttpHeaders} and no-op {@literal error listener}. @@ -56,7 +57,7 @@ class DefaultWebClientProvider implements WebClientProvider { * @param connector can be {@literal null}. */ DefaultWebClientProvider(String scheme, @Nullable ClientHttpConnector connector) { - this(scheme, connector, e -> {}, HttpHeaders.EMPTY, null, Function.identity()); + this(scheme, connector, e -> {}, HttpHeaders.EMPTY, null, Function.identity(), requestHeadersSpec -> {}); } /** @@ -66,18 +67,21 @@ class DefaultWebClientProvider implements WebClientProvider { * @param connector can be {@literal null}. * @param errorListener must not be {@literal null}. * @param headers must not be {@literal null}. - * @param pathPrefix can be {@literal null} + * @param pathPrefix can be {@literal null}. * @param webClientConfigurer must not be {@literal null}. + * @param requestConfigurer must not be {@literal null}. */ private DefaultWebClientProvider(String scheme, @Nullable ClientHttpConnector connector, Consumer errorListener, HttpHeaders headers, @Nullable String pathPrefix, - Function webClientConfigurer) { + Function webClientConfigurer, Consumer> requestConfigurer) { Assert.notNull(scheme, "Scheme must not be null! A common scheme would be 'http'."); Assert.notNull(errorListener, "errorListener must not be null! You may want use a no-op one 'e -> {}' instead."); Assert.notNull(headers, "headers must not be null! Think about using 'HttpHeaders.EMPTY' as an alternative."); Assert.notNull(webClientConfigurer, "webClientConfigurer must not be null! You may want use a no-op one 'Function.identity()' instead."); + Assert.notNull(requestConfigurer, + "requestConfigurer must not be null! You may want use a no-op one 'r -> {}' instead.\""); this.cachedClients = new ConcurrentHashMap<>(); this.scheme = scheme; @@ -86,6 +90,7 @@ class DefaultWebClientProvider implements WebClientProvider { this.headers = headers; this.pathPrefix = pathPrefix; this.webClientConfigurer = webClientConfigurer; + this.requestConfigurer = requestConfigurer; } @Override @@ -106,6 +111,7 @@ class DefaultWebClientProvider implements WebClientProvider { return this.errorListener; } + @Nullable @Override public String getPathPrefix() { return pathPrefix; @@ -120,7 +126,17 @@ class DefaultWebClientProvider implements WebClientProvider { merged.addAll(this.headers); merged.addAll(headers); - return new DefaultWebClientProvider(scheme, connector, errorListener, merged, pathPrefix, webClientConfigurer); + return new DefaultWebClientProvider(scheme, connector, errorListener, merged, pathPrefix, webClientConfigurer, + requestConfigurer); + } + + @Override + public WebClientProvider withRequestConfigurer(Consumer> requestConfigurer) { + + Assert.notNull(requestConfigurer, "requestConfigurer must not be null."); + + return new DefaultWebClientProvider(scheme, connector, errorListener, headers, pathPrefix, webClientConfigurer, + requestConfigurer); } @Override @@ -129,7 +145,8 @@ class DefaultWebClientProvider implements WebClientProvider { Assert.notNull(errorListener, "Error listener must not be null."); Consumer listener = this.errorListener.andThen(errorListener); - return new DefaultWebClientProvider(scheme, this.connector, listener, headers, pathPrefix, webClientConfigurer); + return new DefaultWebClientProvider(scheme, this.connector, listener, headers, pathPrefix, webClientConfigurer, + requestConfigurer); } @Override @@ -137,18 +154,21 @@ class DefaultWebClientProvider implements WebClientProvider { Assert.notNull(pathPrefix, "pathPrefix must not be null."); return new DefaultWebClientProvider(this.scheme, this.connector, this.errorListener, this.headers, pathPrefix, - webClientConfigurer); + webClientConfigurer, requestConfigurer); } @Override public WebClientProvider withWebClientConfigurer(Function webClientConfigurer) { - return new DefaultWebClientProvider(scheme, connector, errorListener, headers, pathPrefix, webClientConfigurer); + return new DefaultWebClientProvider(scheme, connector, errorListener, headers, pathPrefix, webClientConfigurer, + requestConfigurer); } protected WebClient createWebClientForSocketAddress(InetSocketAddress socketAddress) { - Builder builder = WebClient.builder().defaultHeaders(it -> it.addAll(getDefaultHeaders())); + Builder builder = WebClient.builder() // + .defaultHeaders(it -> it.addAll(getDefaultHeaders())) // + .defaultRequest(requestConfigurer); if (connector != null) { builder = builder.clientConnector(connector); diff --git a/src/main/java/org/springframework/data/elasticsearch/client/reactive/HostProvider.java b/src/main/java/org/springframework/data/elasticsearch/client/reactive/HostProvider.java index a5e49a62e..f122738a6 100644 --- a/src/main/java/org/springframework/data/elasticsearch/client/reactive/HostProvider.java +++ b/src/main/java/org/springframework/data/elasticsearch/client/reactive/HostProvider.java @@ -54,9 +54,9 @@ public interface HostProvider> { Assert.notEmpty(endpoints, "Please provide at least one endpoint to connect to."); if (endpoints.length == 1) { - return new SingleNodeHostProvider(clientProvider, headersSupplier, endpoints[0]); + return new SingleNodeHostProvider(clientProvider, endpoints[0]); } else { - return new MultiNodeHostProvider(clientProvider, headersSupplier, endpoints); + return new MultiNodeHostProvider(clientProvider, endpoints); } } diff --git a/src/main/java/org/springframework/data/elasticsearch/client/reactive/MultiNodeHostProvider.java b/src/main/java/org/springframework/data/elasticsearch/client/reactive/MultiNodeHostProvider.java index 4250c6bdb..02c989483 100644 --- a/src/main/java/org/springframework/data/elasticsearch/client/reactive/MultiNodeHostProvider.java +++ b/src/main/java/org/springframework/data/elasticsearch/client/reactive/MultiNodeHostProvider.java @@ -28,14 +28,12 @@ import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import java.util.function.Supplier; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.data.elasticsearch.client.ElasticsearchHost; import org.springframework.data.elasticsearch.client.ElasticsearchHost.State; import org.springframework.data.elasticsearch.client.NoReachableHostException; -import org.springframework.http.HttpHeaders; import org.springframework.lang.Nullable; import org.springframework.web.reactive.function.client.ClientResponse; import org.springframework.web.reactive.function.client.WebClient; @@ -53,14 +51,11 @@ class MultiNodeHostProvider implements HostProvider { private final static Logger LOG = LoggerFactory.getLogger(MultiNodeHostProvider.class); private final WebClientProvider clientProvider; - private final Supplier headersSupplier; private final Map hosts; - MultiNodeHostProvider(WebClientProvider clientProvider, Supplier headersSupplier, - InetSocketAddress... endpoints) { + MultiNodeHostProvider(WebClientProvider clientProvider, InetSocketAddress... endpoints) { this.clientProvider = clientProvider; - this.headersSupplier = headersSupplier; this.hosts = new ConcurrentHashMap<>(); for (InetSocketAddress endpoint : endpoints) { this.hosts.put(endpoint, new ElasticsearchHost(endpoint, State.UNKNOWN)); @@ -166,7 +161,6 @@ class MultiNodeHostProvider implements HostProvider { Mono clientResponseMono = createWebClient(host) // .head().uri("/") // - .headers(httpHeaders -> httpHeaders.addAll(headersSupplier.get())) // .exchangeToMono(Mono::just) // .timeout(Duration.ofSeconds(1)) // .doOnError(throwable -> { diff --git a/src/main/java/org/springframework/data/elasticsearch/client/reactive/SingleNodeHostProvider.java b/src/main/java/org/springframework/data/elasticsearch/client/reactive/SingleNodeHostProvider.java index 227b29201..b576d3653 100644 --- a/src/main/java/org/springframework/data/elasticsearch/client/reactive/SingleNodeHostProvider.java +++ b/src/main/java/org/springframework/data/elasticsearch/client/reactive/SingleNodeHostProvider.java @@ -19,12 +19,10 @@ import reactor.core.publisher.Mono; import java.net.InetSocketAddress; import java.util.Collections; -import java.util.function.Supplier; import org.springframework.data.elasticsearch.client.ElasticsearchHost; import org.springframework.data.elasticsearch.client.ElasticsearchHost.State; import org.springframework.data.elasticsearch.client.NoReachableHostException; -import org.springframework.http.HttpHeaders; import org.springframework.web.reactive.function.client.WebClient; /** @@ -38,15 +36,12 @@ import org.springframework.web.reactive.function.client.WebClient; class SingleNodeHostProvider implements HostProvider { private final WebClientProvider clientProvider; - private final Supplier headersSupplier; private final InetSocketAddress endpoint; private volatile ElasticsearchHost state; - SingleNodeHostProvider(WebClientProvider clientProvider, Supplier headersSupplier, - InetSocketAddress endpoint) { + SingleNodeHostProvider(WebClientProvider clientProvider, InetSocketAddress endpoint) { this.clientProvider = clientProvider; - this.headersSupplier = headersSupplier; this.endpoint = endpoint; this.state = new ElasticsearchHost(this.endpoint, State.UNKNOWN); } @@ -60,7 +55,6 @@ class SingleNodeHostProvider implements HostProvider { return createWebClient(endpoint) // .head().uri("/") // - .headers(httpHeaders -> httpHeaders.addAll(headersSupplier.get())) // .exchangeToMono(it -> { if (it.statusCode().isError()) { state = ElasticsearchHost.offline(endpoint); diff --git a/src/main/java/org/springframework/data/elasticsearch/client/reactive/WebClientProvider.java b/src/main/java/org/springframework/data/elasticsearch/client/reactive/WebClientProvider.java index 994667536..5c092a248 100644 --- a/src/main/java/org/springframework/data/elasticsearch/client/reactive/WebClientProvider.java +++ b/src/main/java/org/springframework/data/elasticsearch/client/reactive/WebClientProvider.java @@ -101,7 +101,7 @@ public interface WebClientProvider { /** * Obtain the {@link String pathPrefix} to be used. - * + * * @return the pathPrefix if set. * @since 4.0 */ @@ -126,7 +126,7 @@ public interface WebClientProvider { /** * Create a new instance of {@link WebClientProvider} where HTTP requests are called with the given path prefix. - * + * * @param pathPrefix Path prefix to add to requests * @return new instance of {@link WebClientProvider} * @since 4.0 @@ -136,10 +136,20 @@ public interface WebClientProvider { /** * Create a new instance of {@link WebClientProvider} calling the given {@link Function} to configure the * {@link WebClient}. - * + * * @param webClientConfigurer configuration function * @return new instance of {@link WebClientProvider} * @since 4.0 */ WebClientProvider withWebClientConfigurer(Function webClientConfigurer); + + /** + * Create a new instance of {@link WebClientProvider} calling the given {@link Consumer} to configure the requests of + * this {@link WebClient}. + * + * @param requestConfigurer request configuration callback + * @return new instance of {@link WebClientProvider} + * @since 4.3 + */ + WebClientProvider withRequestConfigurer(Consumer> requestConfigurer); } diff --git a/src/test/java/org/springframework/data/elasticsearch/client/reactive/ReactiveMockClientTestsUtils.java b/src/test/java/org/springframework/data/elasticsearch/client/reactive/ReactiveMockClientTestsUtils.java index 8d17705f9..9f9cecc95 100644 --- a/src/test/java/org/springframework/data/elasticsearch/client/reactive/ReactiveMockClientTestsUtils.java +++ b/src/test/java/org/springframework/data/elasticsearch/client/reactive/ReactiveMockClientTestsUtils.java @@ -83,10 +83,10 @@ public class ReactiveMockClientTestsUtils { if (hosts.length == 1) { // noinspection unchecked - delegate = (T) new SingleNodeHostProvider(clientProvider, HttpHeaders::new, getInetSocketAddress(hosts[0])) {}; + delegate = (T) new SingleNodeHostProvider(clientProvider, getInetSocketAddress(hosts[0])) {}; } else { // noinspection unchecked - delegate = (T) new MultiNodeHostProvider(clientProvider, HttpHeaders::new, Arrays.stream(hosts) + delegate = (T) new MultiNodeHostProvider(clientProvider, Arrays.stream(hosts) .map(ReactiveMockClientTestsUtils::getInetSocketAddress).toArray(InetSocketAddress[]::new)) {}; } @@ -297,6 +297,11 @@ public class ReactiveMockClientTestsUtils { throw new UnsupportedOperationException("not implemented"); } + @Override + public WebClientProvider withRequestConfigurer(Consumer> requestConfigurer) { + throw new UnsupportedOperationException("not implemented"); + } + public Send when(String host) { InetSocketAddress inetSocketAddress = getInetSocketAddress(host); return new CallbackImpl(get(host), headersUriSpecMap.get(inetSocketAddress),