Refactor DefaultReactiveElasticsearchClient to do request customization with the WebClient. (#1795)

Original Pull Request #1795 
Closes #1794
This commit is contained in:
Peter-Josef Meisch 2021-04-30 06:48:07 +02:00 committed by GitHub
parent f8fbf7721a
commit 775bf66401
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 66 additions and 41 deletions

View File

@ -281,15 +281,23 @@ public class DefaultReactiveElasticsearchClient implements ReactiveElasticsearch
scheme = "https"; scheme = "https";
} }
ReactorClientHttpConnector connector = new ReactorClientHttpConnector(httpClient); WebClientProvider provider = WebClientProvider.create(scheme, new ReactorClientHttpConnector(httpClient));
WebClientProvider provider = WebClientProvider.create(scheme, connector);
if (clientConfiguration.getPathPrefix() != null) { if (clientConfiguration.getPathPrefix() != null) {
provider = provider.withPathPrefix(clientConfiguration.getPathPrefix()); provider = provider.withPathPrefix(clientConfiguration.getPathPrefix());
} }
provider = provider.withDefaultHeaders(clientConfiguration.getDefaultHeaders()) // provider = provider //
.withWebClientConfigurer(clientConfiguration.getWebClientConfigurer()); .withDefaultHeaders(clientConfiguration.getDefaultHeaders()) //
.withWebClientConfigurer(clientConfiguration.getWebClientConfigurer()) //
.withRequestConfigurer(requestHeadersSpec -> requestHeadersSpec.headers(httpHeaders -> {
HttpHeaders suppliedHeaders = clientConfiguration.getHeadersSupplier().get();
if (suppliedHeaders != null && suppliedHeaders != HttpHeaders.EMPTY) {
httpHeaders.addAll(suppliedHeaders);
}
}));
return provider; return provider;
} }
@ -584,12 +592,6 @@ public class DefaultReactiveElasticsearchClient implements ReactiveElasticsearch
request.getOptions().getHeaders().forEach(it -> theHeaders.add(it.getName(), it.getValue())); request.getOptions().getHeaders().forEach(it -> theHeaders.add(it.getName(), it.getValue()));
} }
} }
// plus the ones from the supplier
HttpHeaders suppliedHeaders = headersSupplier.get();
if (suppliedHeaders != null && suppliedHeaders != HttpHeaders.EMPTY) {
theHeaders.addAll(suppliedHeaders);
}
}); });
if (request.getEntity() != null) { if (request.getEntity() != null) {
@ -599,8 +601,8 @@ public class DefaultReactiveElasticsearchClient implements ReactiveElasticsearch
ClientLogger.logRequest(logId, request.getMethod().toUpperCase(), request.getEndpoint(), request.getParameters(), ClientLogger.logRequest(logId, request.getMethod().toUpperCase(), request.getEndpoint(), request.getParameters(),
body::get); body::get);
requestBodySpec.contentType(MediaType.valueOf(request.getEntity().getContentType().getValue())); requestBodySpec.contentType(MediaType.valueOf(request.getEntity().getContentType().getValue()))
requestBodySpec.body(Mono.fromSupplier(body), String.class); .body(Mono.fromSupplier(body), String.class);
} else { } else {
ClientLogger.logRequest(logId, request.getMethod().toUpperCase(), request.getEndpoint(), request.getParameters()); ClientLogger.logRequest(logId, request.getMethod().toUpperCase(), request.getEndpoint(), request.getParameters());
} }

View File

@ -48,6 +48,7 @@ class DefaultWebClientProvider implements WebClientProvider {
private final HttpHeaders headers; private final HttpHeaders headers;
private final @Nullable String pathPrefix; private final @Nullable String pathPrefix;
private final Function<WebClient, WebClient> webClientConfigurer; private final Function<WebClient, WebClient> webClientConfigurer;
private final Consumer<WebClient.RequestHeadersSpec<?>> requestConfigurer;
/** /**
* Create new {@link DefaultWebClientProvider} with empty {@link HttpHeaders} and no-op {@literal error listener}. * Create new {@link DefaultWebClientProvider} with empty {@link HttpHeaders} and no-op {@literal error listener}.
@ -56,7 +57,7 @@ class DefaultWebClientProvider implements WebClientProvider {
* @param connector can be {@literal null}. * @param connector can be {@literal null}.
*/ */
DefaultWebClientProvider(String scheme, @Nullable ClientHttpConnector connector) { DefaultWebClientProvider(String scheme, @Nullable ClientHttpConnector connector) {
this(scheme, connector, e -> {}, HttpHeaders.EMPTY, null, Function.identity()); this(scheme, connector, e -> {}, HttpHeaders.EMPTY, null, Function.identity(), requestHeadersSpec -> {});
} }
/** /**
@ -66,18 +67,21 @@ class DefaultWebClientProvider implements WebClientProvider {
* @param connector can be {@literal null}. * @param connector can be {@literal null}.
* @param errorListener must not be {@literal null}. * @param errorListener must not be {@literal null}.
* @param headers must not be {@literal null}. * @param headers must not be {@literal null}.
* @param pathPrefix can be {@literal null} * @param pathPrefix can be {@literal null}.
* @param webClientConfigurer must not be {@literal null}. * @param webClientConfigurer must not be {@literal null}.
* @param requestConfigurer must not be {@literal null}.
*/ */
private DefaultWebClientProvider(String scheme, @Nullable ClientHttpConnector connector, private DefaultWebClientProvider(String scheme, @Nullable ClientHttpConnector connector,
Consumer<Throwable> errorListener, HttpHeaders headers, @Nullable String pathPrefix, Consumer<Throwable> errorListener, HttpHeaders headers, @Nullable String pathPrefix,
Function<WebClient, WebClient> webClientConfigurer) { Function<WebClient, WebClient> webClientConfigurer, Consumer<WebClient.RequestHeadersSpec<?>> requestConfigurer) {
Assert.notNull(scheme, "Scheme must not be null! A common scheme would be 'http'."); Assert.notNull(scheme, "Scheme must not be null! A common scheme would be 'http'.");
Assert.notNull(errorListener, "errorListener must not be null! You may want use a no-op one 'e -> {}' instead."); Assert.notNull(errorListener, "errorListener must not be null! You may want use a no-op one 'e -> {}' instead.");
Assert.notNull(headers, "headers must not be null! Think about using 'HttpHeaders.EMPTY' as an alternative."); Assert.notNull(headers, "headers must not be null! Think about using 'HttpHeaders.EMPTY' as an alternative.");
Assert.notNull(webClientConfigurer, Assert.notNull(webClientConfigurer,
"webClientConfigurer must not be null! You may want use a no-op one 'Function.identity()' instead."); "webClientConfigurer must not be null! You may want use a no-op one 'Function.identity()' instead.");
Assert.notNull(requestConfigurer,
"requestConfigurer must not be null! You may want use a no-op one 'r -> {}' instead.\"");
this.cachedClients = new ConcurrentHashMap<>(); this.cachedClients = new ConcurrentHashMap<>();
this.scheme = scheme; this.scheme = scheme;
@ -86,6 +90,7 @@ class DefaultWebClientProvider implements WebClientProvider {
this.headers = headers; this.headers = headers;
this.pathPrefix = pathPrefix; this.pathPrefix = pathPrefix;
this.webClientConfigurer = webClientConfigurer; this.webClientConfigurer = webClientConfigurer;
this.requestConfigurer = requestConfigurer;
} }
@Override @Override
@ -106,6 +111,7 @@ class DefaultWebClientProvider implements WebClientProvider {
return this.errorListener; return this.errorListener;
} }
@Nullable
@Override @Override
public String getPathPrefix() { public String getPathPrefix() {
return pathPrefix; return pathPrefix;
@ -120,7 +126,17 @@ class DefaultWebClientProvider implements WebClientProvider {
merged.addAll(this.headers); merged.addAll(this.headers);
merged.addAll(headers); merged.addAll(headers);
return new DefaultWebClientProvider(scheme, connector, errorListener, merged, pathPrefix, webClientConfigurer); return new DefaultWebClientProvider(scheme, connector, errorListener, merged, pathPrefix, webClientConfigurer,
requestConfigurer);
}
@Override
public WebClientProvider withRequestConfigurer(Consumer<WebClient.RequestHeadersSpec<?>> requestConfigurer) {
Assert.notNull(requestConfigurer, "requestConfigurer must not be null.");
return new DefaultWebClientProvider(scheme, connector, errorListener, headers, pathPrefix, webClientConfigurer,
requestConfigurer);
} }
@Override @Override
@ -129,7 +145,8 @@ class DefaultWebClientProvider implements WebClientProvider {
Assert.notNull(errorListener, "Error listener must not be null."); Assert.notNull(errorListener, "Error listener must not be null.");
Consumer<Throwable> listener = this.errorListener.andThen(errorListener); Consumer<Throwable> listener = this.errorListener.andThen(errorListener);
return new DefaultWebClientProvider(scheme, this.connector, listener, headers, pathPrefix, webClientConfigurer); return new DefaultWebClientProvider(scheme, this.connector, listener, headers, pathPrefix, webClientConfigurer,
requestConfigurer);
} }
@Override @Override
@ -137,18 +154,21 @@ class DefaultWebClientProvider implements WebClientProvider {
Assert.notNull(pathPrefix, "pathPrefix must not be null."); Assert.notNull(pathPrefix, "pathPrefix must not be null.");
return new DefaultWebClientProvider(this.scheme, this.connector, this.errorListener, this.headers, pathPrefix, return new DefaultWebClientProvider(this.scheme, this.connector, this.errorListener, this.headers, pathPrefix,
webClientConfigurer); webClientConfigurer, requestConfigurer);
} }
@Override @Override
public WebClientProvider withWebClientConfigurer(Function<WebClient, WebClient> webClientConfigurer) { public WebClientProvider withWebClientConfigurer(Function<WebClient, WebClient> webClientConfigurer) {
return new DefaultWebClientProvider(scheme, connector, errorListener, headers, pathPrefix, webClientConfigurer); return new DefaultWebClientProvider(scheme, connector, errorListener, headers, pathPrefix, webClientConfigurer,
requestConfigurer);
} }
protected WebClient createWebClientForSocketAddress(InetSocketAddress socketAddress) { protected WebClient createWebClientForSocketAddress(InetSocketAddress socketAddress) {
Builder builder = WebClient.builder().defaultHeaders(it -> it.addAll(getDefaultHeaders())); Builder builder = WebClient.builder() //
.defaultHeaders(it -> it.addAll(getDefaultHeaders())) //
.defaultRequest(requestConfigurer);
if (connector != null) { if (connector != null) {
builder = builder.clientConnector(connector); builder = builder.clientConnector(connector);

View File

@ -54,9 +54,9 @@ public interface HostProvider<T extends HostProvider<T>> {
Assert.notEmpty(endpoints, "Please provide at least one endpoint to connect to."); Assert.notEmpty(endpoints, "Please provide at least one endpoint to connect to.");
if (endpoints.length == 1) { if (endpoints.length == 1) {
return new SingleNodeHostProvider(clientProvider, headersSupplier, endpoints[0]); return new SingleNodeHostProvider(clientProvider, endpoints[0]);
} else { } else {
return new MultiNodeHostProvider(clientProvider, headersSupplier, endpoints); return new MultiNodeHostProvider(clientProvider, endpoints);
} }
} }

View File

@ -28,14 +28,12 @@ import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.springframework.data.elasticsearch.client.ElasticsearchHost; import org.springframework.data.elasticsearch.client.ElasticsearchHost;
import org.springframework.data.elasticsearch.client.ElasticsearchHost.State; import org.springframework.data.elasticsearch.client.ElasticsearchHost.State;
import org.springframework.data.elasticsearch.client.NoReachableHostException; import org.springframework.data.elasticsearch.client.NoReachableHostException;
import org.springframework.http.HttpHeaders;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
import org.springframework.web.reactive.function.client.ClientResponse; import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.client.WebClient;
@ -53,14 +51,11 @@ class MultiNodeHostProvider implements HostProvider<MultiNodeHostProvider> {
private final static Logger LOG = LoggerFactory.getLogger(MultiNodeHostProvider.class); private final static Logger LOG = LoggerFactory.getLogger(MultiNodeHostProvider.class);
private final WebClientProvider clientProvider; private final WebClientProvider clientProvider;
private final Supplier<HttpHeaders> headersSupplier;
private final Map<InetSocketAddress, ElasticsearchHost> hosts; private final Map<InetSocketAddress, ElasticsearchHost> hosts;
MultiNodeHostProvider(WebClientProvider clientProvider, Supplier<HttpHeaders> headersSupplier, MultiNodeHostProvider(WebClientProvider clientProvider, InetSocketAddress... endpoints) {
InetSocketAddress... endpoints) {
this.clientProvider = clientProvider; this.clientProvider = clientProvider;
this.headersSupplier = headersSupplier;
this.hosts = new ConcurrentHashMap<>(); this.hosts = new ConcurrentHashMap<>();
for (InetSocketAddress endpoint : endpoints) { for (InetSocketAddress endpoint : endpoints) {
this.hosts.put(endpoint, new ElasticsearchHost(endpoint, State.UNKNOWN)); this.hosts.put(endpoint, new ElasticsearchHost(endpoint, State.UNKNOWN));
@ -166,7 +161,6 @@ class MultiNodeHostProvider implements HostProvider<MultiNodeHostProvider> {
Mono<ClientResponse> clientResponseMono = createWebClient(host) // Mono<ClientResponse> clientResponseMono = createWebClient(host) //
.head().uri("/") // .head().uri("/") //
.headers(httpHeaders -> httpHeaders.addAll(headersSupplier.get())) //
.exchangeToMono(Mono::just) // .exchangeToMono(Mono::just) //
.timeout(Duration.ofSeconds(1)) // .timeout(Duration.ofSeconds(1)) //
.doOnError(throwable -> { .doOnError(throwable -> {

View File

@ -19,12 +19,10 @@ import reactor.core.publisher.Mono;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.util.Collections; import java.util.Collections;
import java.util.function.Supplier;
import org.springframework.data.elasticsearch.client.ElasticsearchHost; import org.springframework.data.elasticsearch.client.ElasticsearchHost;
import org.springframework.data.elasticsearch.client.ElasticsearchHost.State; import org.springframework.data.elasticsearch.client.ElasticsearchHost.State;
import org.springframework.data.elasticsearch.client.NoReachableHostException; import org.springframework.data.elasticsearch.client.NoReachableHostException;
import org.springframework.http.HttpHeaders;
import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.client.WebClient;
/** /**
@ -38,15 +36,12 @@ import org.springframework.web.reactive.function.client.WebClient;
class SingleNodeHostProvider implements HostProvider<SingleNodeHostProvider> { class SingleNodeHostProvider implements HostProvider<SingleNodeHostProvider> {
private final WebClientProvider clientProvider; private final WebClientProvider clientProvider;
private final Supplier<HttpHeaders> headersSupplier;
private final InetSocketAddress endpoint; private final InetSocketAddress endpoint;
private volatile ElasticsearchHost state; private volatile ElasticsearchHost state;
SingleNodeHostProvider(WebClientProvider clientProvider, Supplier<HttpHeaders> headersSupplier, SingleNodeHostProvider(WebClientProvider clientProvider, InetSocketAddress endpoint) {
InetSocketAddress endpoint) {
this.clientProvider = clientProvider; this.clientProvider = clientProvider;
this.headersSupplier = headersSupplier;
this.endpoint = endpoint; this.endpoint = endpoint;
this.state = new ElasticsearchHost(this.endpoint, State.UNKNOWN); this.state = new ElasticsearchHost(this.endpoint, State.UNKNOWN);
} }
@ -60,7 +55,6 @@ class SingleNodeHostProvider implements HostProvider<SingleNodeHostProvider> {
return createWebClient(endpoint) // return createWebClient(endpoint) //
.head().uri("/") // .head().uri("/") //
.headers(httpHeaders -> httpHeaders.addAll(headersSupplier.get())) //
.exchangeToMono(it -> { .exchangeToMono(it -> {
if (it.statusCode().isError()) { if (it.statusCode().isError()) {
state = ElasticsearchHost.offline(endpoint); state = ElasticsearchHost.offline(endpoint);

View File

@ -101,7 +101,7 @@ public interface WebClientProvider {
/** /**
* Obtain the {@link String pathPrefix} to be used. * Obtain the {@link String pathPrefix} to be used.
* *
* @return the pathPrefix if set. * @return the pathPrefix if set.
* @since 4.0 * @since 4.0
*/ */
@ -126,7 +126,7 @@ public interface WebClientProvider {
/** /**
* Create a new instance of {@link WebClientProvider} where HTTP requests are called with the given path prefix. * Create a new instance of {@link WebClientProvider} where HTTP requests are called with the given path prefix.
* *
* @param pathPrefix Path prefix to add to requests * @param pathPrefix Path prefix to add to requests
* @return new instance of {@link WebClientProvider} * @return new instance of {@link WebClientProvider}
* @since 4.0 * @since 4.0
@ -136,10 +136,20 @@ public interface WebClientProvider {
/** /**
* Create a new instance of {@link WebClientProvider} calling the given {@link Function} to configure the * Create a new instance of {@link WebClientProvider} calling the given {@link Function} to configure the
* {@link WebClient}. * {@link WebClient}.
* *
* @param webClientConfigurer configuration function * @param webClientConfigurer configuration function
* @return new instance of {@link WebClientProvider} * @return new instance of {@link WebClientProvider}
* @since 4.0 * @since 4.0
*/ */
WebClientProvider withWebClientConfigurer(Function<WebClient, WebClient> webClientConfigurer); WebClientProvider withWebClientConfigurer(Function<WebClient, WebClient> webClientConfigurer);
/**
* Create a new instance of {@link WebClientProvider} calling the given {@link Consumer} to configure the requests of
* this {@link WebClient}.
*
* @param requestConfigurer request configuration callback
* @return new instance of {@link WebClientProvider}
* @since 4.3
*/
WebClientProvider withRequestConfigurer(Consumer<WebClient.RequestHeadersSpec<?>> requestConfigurer);
} }

View File

@ -83,10 +83,10 @@ public class ReactiveMockClientTestsUtils {
if (hosts.length == 1) { if (hosts.length == 1) {
// noinspection unchecked // noinspection unchecked
delegate = (T) new SingleNodeHostProvider(clientProvider, HttpHeaders::new, getInetSocketAddress(hosts[0])) {}; delegate = (T) new SingleNodeHostProvider(clientProvider, getInetSocketAddress(hosts[0])) {};
} else { } else {
// noinspection unchecked // noinspection unchecked
delegate = (T) new MultiNodeHostProvider(clientProvider, HttpHeaders::new, Arrays.stream(hosts) delegate = (T) new MultiNodeHostProvider(clientProvider, Arrays.stream(hosts)
.map(ReactiveMockClientTestsUtils::getInetSocketAddress).toArray(InetSocketAddress[]::new)) {}; .map(ReactiveMockClientTestsUtils::getInetSocketAddress).toArray(InetSocketAddress[]::new)) {};
} }
@ -297,6 +297,11 @@ public class ReactiveMockClientTestsUtils {
throw new UnsupportedOperationException("not implemented"); throw new UnsupportedOperationException("not implemented");
} }
@Override
public WebClientProvider withRequestConfigurer(Consumer<WebClient.RequestHeadersSpec<?>> requestConfigurer) {
throw new UnsupportedOperationException("not implemented");
}
public Send when(String host) { public Send when(String host) {
InetSocketAddress inetSocketAddress = getInetSocketAddress(host); InetSocketAddress inetSocketAddress = getInetSocketAddress(host);
return new CallbackImpl(get(host), headersUriSpecMap.get(inetSocketAddress), return new CallbackImpl(get(host), headersUriSpecMap.get(inetSocketAddress),