Improve client configuration callbacks.

Original Pull Request #1930 
Closes #1929
This commit is contained in:
Peter-Josef Meisch 2021-09-10 17:20:15 +02:00 committed by GitHub
parent 7c35e5327e
commit 8ab84fcc7a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 79 additions and 40 deletions

View File

@ -166,15 +166,15 @@ ClientConfiguration clientConfiguration = ClientConfiguration.builder()
return headers; return headers;
}) })
.withClientConfigurer( <.> .withClientConfigurer( <.>
(ReactiveRestClients.WebClientConfigurationCallback) webClient -> { ReactiveRestClients.WebClientConfigurationCallback.from(webClient -> {
// ... // ...
return webClient; return webClient;
}) }))
.withClientConfigurer( <.> .withClientConfigurer( <.>
(RestClients.RestClientConfigurationCallback) clientBuilder -> { RestClients.RestClientConfigurationCallback.from(clientBuilder -> {
// ... // ...
return clientBuilder; return clientBuilder;
}) }))
. // ... other options . // ... other options
.build(); .build();

View File

@ -176,16 +176,16 @@ public interface ClientConfiguration {
/** /**
* @return the Rest Client configuration callback. * @return the Rest Client configuration callback.
* @since 4.2 * @since 4.2
* @deprecated since 4.3 use {@link #getClientConfigurer()} * @deprecated since 4.3 use {@link #getClientConfigurers()}
*/ */
@Deprecated @Deprecated
HttpClientConfigCallback getHttpClientConfigurer(); HttpClientConfigCallback getHttpClientConfigurer();
/** /**
* @return the client configuration callback * @return the client configuration callbacks
* @since 4.3 * @since 4.3
*/ */
<T> ClientConfigurationCallback<T> getClientConfigurer(); <T> List<ClientConfigurationCallback<?>> getClientConfigurers();
/** /**
* @return the supplier for custom headers. * @return the supplier for custom headers.

View File

@ -64,7 +64,7 @@ class ClientConfigurationBuilder
private Function<WebClient, WebClient> webClientConfigurer = Function.identity(); private Function<WebClient, WebClient> webClientConfigurer = Function.identity();
private Supplier<HttpHeaders> headersSupplier = () -> HttpHeaders.EMPTY; private Supplier<HttpHeaders> headersSupplier = () -> HttpHeaders.EMPTY;
@Deprecated private HttpClientConfigCallback httpClientConfigurer = httpClientBuilder -> httpClientBuilder; @Deprecated private HttpClientConfigCallback httpClientConfigurer = httpClientBuilder -> httpClientBuilder;
private ClientConfiguration.ClientConfigurationCallback<?> clientConfigurer = t -> t; private List<ClientConfiguration.ClientConfigurationCallback<?>> clientConfigurers = new ArrayList<>();
/* /*
* (non-Javadoc) * (non-Javadoc)
@ -208,9 +208,7 @@ class ClientConfigurationBuilder
Assert.notNull(webClientConfigurer, "webClientConfigurer must not be null"); Assert.notNull(webClientConfigurer, "webClientConfigurer must not be null");
this.webClientConfigurer = webClientConfigurer; this.webClientConfigurer = webClientConfigurer;
// noinspection NullableProblems this.clientConfigurers.add(ReactiveRestClients.WebClientConfigurationCallback.from(webClientConfigurer));
this.clientConfigurer = (ReactiveRestClients.WebClientConfigurationCallback) webClientConfigurer::apply;
return this; return this;
} }
@ -220,9 +218,8 @@ class ClientConfigurationBuilder
Assert.notNull(httpClientConfigurer, "httpClientConfigurer must not be null"); Assert.notNull(httpClientConfigurer, "httpClientConfigurer must not be null");
this.httpClientConfigurer = httpClientConfigurer; this.httpClientConfigurer = httpClientConfigurer;
// noinspection NullableProblems this.clientConfigurers
this.clientConfigurer = (RestClients.RestClientConfigurationCallback) httpClientConfigurer::customizeHttpClient; .add(RestClients.RestClientConfigurationCallback.from(httpClientConfigurer::customizeHttpClient));
return this; return this;
} }
@ -232,7 +229,7 @@ class ClientConfigurationBuilder
Assert.notNull(clientConfigurer, "clientConfigurer must not be null"); Assert.notNull(clientConfigurer, "clientConfigurer must not be null");
this.clientConfigurer = clientConfigurer; this.clientConfigurers.add(clientConfigurer);
return this; return this;
} }
@ -260,7 +257,7 @@ class ClientConfigurationBuilder
} }
return new DefaultClientConfiguration(hosts, headers, useSsl, sslContext, soTimeout, connectTimeout, pathPrefix, return new DefaultClientConfiguration(hosts, headers, useSsl, sslContext, soTimeout, connectTimeout, pathPrefix,
hostnameVerifier, proxy, webClientConfigurer, httpClientConfigurer, clientConfigurer, headersSupplier); hostnameVerifier, proxy, webClientConfigurer, httpClientConfigurer, clientConfigurers, headersSupplier);
} }
private static InetSocketAddress parse(String hostAndPort) { private static InetSocketAddress parse(String hostAndPort) {

View File

@ -55,13 +55,13 @@ class DefaultClientConfiguration implements ClientConfiguration {
private final Function<WebClient, WebClient> webClientConfigurer; private final Function<WebClient, WebClient> webClientConfigurer;
private final HttpClientConfigCallback httpClientConfigurer; private final HttpClientConfigCallback httpClientConfigurer;
private final Supplier<HttpHeaders> headersSupplier; private final Supplier<HttpHeaders> headersSupplier;
private final ClientConfigurationCallback<?> clientConfigurer; private final List<ClientConfigurationCallback<?>> clientConfigurers;
DefaultClientConfiguration(List<InetSocketAddress> hosts, HttpHeaders headers, boolean useSsl, DefaultClientConfiguration(List<InetSocketAddress> hosts, HttpHeaders headers, boolean useSsl,
@Nullable SSLContext sslContext, Duration soTimeout, Duration connectTimeout, @Nullable String pathPrefix, @Nullable SSLContext sslContext, Duration soTimeout, Duration connectTimeout, @Nullable String pathPrefix,
@Nullable HostnameVerifier hostnameVerifier, @Nullable String proxy, @Nullable HostnameVerifier hostnameVerifier, @Nullable String proxy,
Function<WebClient, WebClient> webClientConfigurer, HttpClientConfigCallback httpClientConfigurer, Function<WebClient, WebClient> webClientConfigurer, HttpClientConfigCallback httpClientConfigurer,
ClientConfigurationCallback<?> clientConfigurer, Supplier<HttpHeaders> headersSupplier) { List<ClientConfigurationCallback<?>> clientConfigurers, Supplier<HttpHeaders> headersSupplier) {
this.hosts = Collections.unmodifiableList(new ArrayList<>(hosts)); this.hosts = Collections.unmodifiableList(new ArrayList<>(hosts));
this.headers = new HttpHeaders(headers); this.headers = new HttpHeaders(headers);
@ -74,7 +74,7 @@ class DefaultClientConfiguration implements ClientConfiguration {
this.proxy = proxy; this.proxy = proxy;
this.webClientConfigurer = webClientConfigurer; this.webClientConfigurer = webClientConfigurer;
this.httpClientConfigurer = httpClientConfigurer; this.httpClientConfigurer = httpClientConfigurer;
this.clientConfigurer = clientConfigurer; this.clientConfigurers = clientConfigurers;
this.headersSupplier = headersSupplier; this.headersSupplier = headersSupplier;
} }
@ -136,8 +136,8 @@ class DefaultClientConfiguration implements ClientConfiguration {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Override @Override
public <T> ClientConfigurationCallback<T> getClientConfigurer() { public <T> List<ClientConfigurationCallback<?>> getClientConfigurers() {
return (ClientConfigurationCallback<T>) clientConfigurer; return clientConfigurers;
} }
@Override @Override

View File

@ -22,6 +22,7 @@ import java.net.InetSocketAddress;
import java.time.Duration; import java.time.Duration;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.function.Function;
import java.util.function.Supplier; import java.util.function.Supplier;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -120,7 +121,13 @@ public final class RestClients {
clientConfiguration.getProxy().map(HttpHost::create).ifPresent(clientBuilder::setProxy); clientConfiguration.getProxy().map(HttpHost::create).ifPresent(clientBuilder::setProxy);
clientBuilder = clientConfiguration.<HttpAsyncClientBuilder> getClientConfigurer().configure(clientBuilder); for (ClientConfiguration.ClientConfigurationCallback<?> clientConfigurer : clientConfiguration
.getClientConfigurers()) {
if (clientConfigurer instanceof RestClientConfigurationCallback) {
RestClientConfigurationCallback restClientConfigurationCallback = (RestClientConfigurationCallback) clientConfigurer;
clientBuilder = restClientConfigurationCallback.configure(clientBuilder);
}
}
return clientBuilder; return clientBuilder;
}); });
@ -242,5 +249,15 @@ public final class RestClients {
* @since 4.3 * @since 4.3
*/ */
public interface RestClientConfigurationCallback public interface RestClientConfigurationCallback
extends ClientConfiguration.ClientConfigurationCallback<HttpAsyncClientBuilder> {} extends ClientConfiguration.ClientConfigurationCallback<HttpAsyncClientBuilder> {
static RestClientConfigurationCallback from(
Function<HttpAsyncClientBuilder, HttpAsyncClientBuilder> clientBuilderCallback) {
Assert.notNull(clientBuilderCallback, "clientBuilderCallback must not be null");
// noinspection NullableProblems
return clientBuilderCallback::apply;
}
}
} }

View File

@ -288,9 +288,21 @@ public class DefaultReactiveElasticsearchClient implements ReactiveElasticsearch
provider = provider.withPathPrefix(clientConfiguration.getPathPrefix()); provider = provider.withPathPrefix(clientConfiguration.getPathPrefix());
} }
Function<WebClient, WebClient> webClientConfigurer = webClient -> {
for (ClientConfiguration.ClientConfigurationCallback<?> clientConfigurer : clientConfiguration
.getClientConfigurers()) {
if (clientConfigurer instanceof ReactiveRestClients.WebClientConfigurationCallback) {
ReactiveRestClients.WebClientConfigurationCallback webClientConfigurationCallback = (ReactiveRestClients.WebClientConfigurationCallback) clientConfigurer;
webClient = webClientConfigurationCallback.configure(webClient);
}
}
return webClient;
};
provider = provider // provider = provider //
.withDefaultHeaders(clientConfiguration.getDefaultHeaders()) // .withDefaultHeaders(clientConfiguration.getDefaultHeaders()) //
.withWebClientConfigurer(clientConfiguration.<WebClient> getClientConfigurer()::configure) // .withWebClientConfigurer(webClientConfigurer) //
.withRequestConfigurer(requestHeadersSpec -> requestHeadersSpec.headers(httpHeaders -> { .withRequestConfigurer(requestHeadersSpec -> requestHeadersSpec.headers(httpHeaders -> {
HttpHeaders suppliedHeaders = clientConfiguration.getHeadersSupplier().get(); HttpHeaders suppliedHeaders = clientConfiguration.getHeadersSupplier().get();

View File

@ -15,6 +15,8 @@
*/ */
package org.springframework.data.elasticsearch.client.reactive; package org.springframework.data.elasticsearch.client.reactive;
import java.util.function.Function;
import org.springframework.data.elasticsearch.client.ClientConfiguration; import org.springframework.data.elasticsearch.client.ClientConfiguration;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.client.WebClient;
@ -69,5 +71,14 @@ public final class ReactiveRestClients {
* *
* @since 4.3 * @since 4.3
*/ */
public interface WebClientConfigurationCallback extends ClientConfiguration.ClientConfigurationCallback<WebClient> {} public interface WebClientConfigurationCallback extends ClientConfiguration.ClientConfigurationCallback<WebClient> {
static WebClientConfigurationCallback from(Function<WebClient, WebClient> webClientCallback) {
Assert.notNull(webClientCallback, "webClientCallback must not be null");
// noinspection NullableProblems
return webClientCallback::apply;
}
}
} }

View File

@ -30,6 +30,7 @@ import org.apache.http.conn.ssl.NoopHostnameVerifier;
import org.apache.http.impl.nio.client.HttpAsyncClientBuilder; import org.apache.http.impl.nio.client.HttpAsyncClientBuilder;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.data.elasticsearch.client.reactive.ReactiveRestClients;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.client.WebClient;
@ -172,10 +173,10 @@ public class ClientConfigurationUnitTests {
}) // }) //
.build(); .build();
ClientConfiguration.ClientConfigurationCallback<HttpAsyncClientBuilder> clientConfigurer = clientConfiguration ClientConfiguration.ClientConfigurationCallback<?> clientConfigurer = clientConfiguration.getClientConfigurers()
.getClientConfigurer(); .get(0);
clientConfigurer.configure(HttpAsyncClientBuilder.create()); ((RestClients.RestClientConfigurationCallback) clientConfigurer).configure(HttpAsyncClientBuilder.create());
assertThat(callCounter.get()).isEqualTo(1); assertThat(callCounter.get()).isEqualTo(1);
} }
@ -193,10 +194,10 @@ public class ClientConfigurationUnitTests {
}) // }) //
.build(); .build();
ClientConfiguration.ClientConfigurationCallback<WebClient> clientConfigurer = clientConfiguration ClientConfiguration.ClientConfigurationCallback<?> clientConfigurer = clientConfiguration.getClientConfigurers()
.getClientConfigurer(); .get(0);
clientConfigurer.configure(WebClient.builder().build()); ((ReactiveRestClients.WebClientConfigurationCallback) clientConfigurer).configure(WebClient.builder().build());
assertThat(callCounter.get()).isEqualTo(1); assertThat(callCounter.get()).isEqualTo(1);
} }
@ -214,10 +215,10 @@ public class ClientConfigurationUnitTests {
}) // }) //
.build(); .build();
ClientConfiguration.ClientConfigurationCallback<Object> clientConfigurer = clientConfiguration ClientConfiguration.ClientConfigurationCallback<?> clientConfigurer = clientConfiguration.getClientConfigurers()
.getClientConfigurer(); .get(0);
clientConfigurer.configure(new Object()); ((ClientConfiguration.ClientConfigurationCallback<Object>) clientConfigurer).configure(new Object());
assertThat(callCounter.get()).isEqualTo(1); assertThat(callCounter.get()).isEqualTo(1);
} }

View File

@ -118,15 +118,16 @@ public class RestClientsTest {
}); });
if (clientUnderTestFactory instanceof RestClientUnderTestFactory) { if (clientUnderTestFactory instanceof RestClientUnderTestFactory) {
configurationBuilder.withClientConfigurer((RestClients.RestClientConfigurationCallback) httpClientBuilder -> { configurationBuilder
clientConfigurerCount.incrementAndGet(); .withClientConfigurer(RestClients.RestClientConfigurationCallback.from(httpClientBuilder -> {
return httpClientBuilder; clientConfigurerCount.incrementAndGet();
}); return httpClientBuilder;
}));
} else if (clientUnderTestFactory instanceof ReactiveElasticsearchClientUnderTestFactory) { } else if (clientUnderTestFactory instanceof ReactiveElasticsearchClientUnderTestFactory) {
configurationBuilder.withClientConfigurer((ReactiveRestClients.WebClientConfigurationCallback) webClient -> { configurationBuilder.withClientConfigurer(ReactiveRestClients.WebClientConfigurationCallback.from(webClient -> {
clientConfigurerCount.incrementAndGet(); clientConfigurerCount.incrementAndGet();
return webClient; return webClient;
}); }));
} }
ClientConfiguration clientConfiguration = configurationBuilder.build(); ClientConfiguration clientConfiguration = configurationBuilder.build();