DATAES-801 - Implement callback to enable adding custom headers in the REST HTTP request.

Original PR: #442
This commit is contained in:
Peter-Josef Meisch 2020-04-26 17:30:46 +02:00 committed by GitHub
parent 65f89f9480
commit a4ec819e7d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 347 additions and 70 deletions

View File

@ -252,7 +252,7 @@
<dependency>
<groupId>com.github.tomakehurst</groupId>
<artifactId>wiremock-jre8</artifactId>
<version>2.25.1</version>
<version>2.26.3</version>
<scope>test</scope>
<exclusions>
<!-- these exclusions are needed because of Elasticsearch JarHell-->

View File

@ -143,40 +143,47 @@ NOTE: The ReactiveClient response, especially for search operations, is bound to
[[elasticsearch.clients.configuration]]
== Client Configuration
Client behaviour can be changed via the `ClientConfiguration` that allows to set options for SSL, connect and socket timeouts.
Client behaviour can be changed via the `ClientConfiguration` that allows to set options for SSL, connect and socket timeouts, headers and other parameters.
.Client Configuration
====
[source,java]
----
// optional if Basic Auhtentication is needed
HttpHeaders httpHeaders = new HttpHeaders();
httpHeaders.add("es-security-runas-user", "some-user") <1>
httpHeaders.add("some-header", "on every request") <1>
ClientConfiguration clientConfiguration = ClientConfiguration.builder()
.connectedTo("localhost:9200", "localhost:9291") <2>
.withProxy("localhost:8888") <3>
.withPathPrefix("ela") <4>
.withConnectTimeout(Duration.ofSeconds(5)) <5>
.withSocketTimeout(Duration.ofSeconds(3)) <6>
.useSsl() <7>
.useSsl() <3>
.withProxy("localhost:8888") <4>
.withPathPrefix("ela") <5>
.withConnectTimeout(Duration.ofSeconds(5)) <6>
.withSocketTimeout(Duration.ofSeconds(3)) <7>
.withDefaultHeaders(defaultHeaders) <8>
.withBasicAuth(username, password) <9>
.withHeaders(() -> { <10>
HttpHeaders headers = new HttpHeaders();
headers.add("currentTime", LocalDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME));
return headers;
})
. // ... other options
.build();
----
<1> Define default headers, if they need to be customized
<2> Use the builder to provide cluster addresses, set default `HttpHeaders` or enable SSL.
<3> Optionally set a proxy footnote:notreactive[not yet implemented for the reactive client].
<4> Optionally set a path prefix, mostly used when different clusters a behind some reverse proxy.
<5> Set the connection timeout. Default is 10 sec.
<6> Set the socket timeout. Default is 5 sec.
<7> Optionally enable SSL.
<3> Optionally enable SSL.
<4> Optionally set a proxy.
<5> Optionally set a path prefix, mostly used when different clusters a behind some reverse proxy.
<6> Set the connection timeout. Default is 10 sec.
<7> Set the socket timeout. Default is 5 sec.
<8> Optionally set headers.
<9> Add basic authentication.
<10> A `Supplier<Header>` function can be specified which is called every time before a request is sent to Elasticsearch - here, as an example, the current time is written in a header.
====
IMPORTANT: Adding a Header supplier as shown in above example allows to inject headers that may change over the time, like authentication JWT tokens. If this is used in the reactive setup, the supplier function *must not* block!
[[elasticsearch.clients.logging]]
== Client Logging

View File

@ -21,6 +21,7 @@ import java.time.Duration;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import java.util.function.Supplier;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLContext;
@ -170,6 +171,11 @@ public interface ClientConfiguration {
*/
Function<WebClient, WebClient> getWebClientConfigurer();
/**
* @return the supplier for custom headers.
*/
Supplier<HttpHeaders> getHeadersSupplier();
/**
* @author Christoph Strobl
*/
@ -335,6 +341,19 @@ public interface ClientConfiguration {
*/
TerminalClientConfigurationBuilder withWebClientConfigurer(Function<WebClient, WebClient> webClientConfigurer);
/**
* set a supplier for custom headers. This is invoked for every HTTP request to Elasticsearch to retrieve headers
* that should be sent with the request. A common use case is passing in authentication headers that may change.
* <br/>
* Note: When used in a reactive environment, the calling of {@link Supplier#get()} function must not do any
* blocking operations. It may return {@literal null}.
*
* @param headers supplier function for headers, must not be {@literal null}
* @return the {@link TerminalClientConfigurationBuilder}.
* @since 4.0
*/
TerminalClientConfigurationBuilder withHeaders(Supplier<HttpHeaders> headers);
/**
* Build the {@link ClientConfiguration} object.
*

View File

@ -21,6 +21,7 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import javax.net.ssl.HostnameVerifier;
@ -58,7 +59,8 @@ class ClientConfigurationBuilder
private @Nullable String password;
private @Nullable String pathPrefix;
private @Nullable String proxy;
private @Nullable Function<WebClient, WebClient> webClientConfigurer;
private Function<WebClient, WebClient> webClientConfigurer = Function.identity();
private Supplier<HttpHeaders> headersSupplier = () -> HttpHeaders.EMPTY;
/*
* (non-Javadoc)
@ -196,7 +198,8 @@ class ClientConfigurationBuilder
}
@Override
public TerminalClientConfigurationBuilder withWebClientConfigurer(Function<WebClient, WebClient> webClientConfigurer) {
public TerminalClientConfigurationBuilder withWebClientConfigurer(
Function<WebClient, WebClient> webClientConfigurer) {
Assert.notNull(webClientConfigurer, "webClientConfigurer must not be null");
@ -204,6 +207,15 @@ class ClientConfigurationBuilder
return this;
}
@Override
public TerminalClientConfigurationBuilder withHeaders(Supplier<HttpHeaders> headers) {
Assert.notNull(headers, "headersSupplier must not be null");
this.headersSupplier = headers;
return this;
}
/*
* (non-Javadoc)
* @see org.springframework.data.elasticsearch.client.ClientConfiguration.ClientConfigurationBuilderWithOptionalDefaultHeaders#build()
@ -219,7 +231,7 @@ class ClientConfigurationBuilder
}
return new DefaultClientConfiguration(hosts, headers, useSsl, sslContext, soTimeout, connectTimeout, pathPrefix,
hostnameVerifier, proxy, webClientConfigurer);
hostnameVerifier, proxy, webClientConfigurer, headersSupplier);
}
private static InetSocketAddress parse(String hostAndPort) {

View File

@ -20,6 +20,7 @@ import java.util.function.Supplier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpStatus;
import org.springframework.lang.Nullable;
import org.springframework.util.ObjectUtils;
/**
@ -90,7 +91,7 @@ public abstract class ClientLogger {
* @param logId the correlation Id, see {@link #newLogId()}.
* @param statusCode the HTTP status code.
*/
public static void logRawResponse(String logId, HttpStatus statusCode) {
public static void logRawResponse(String logId, @Nullable HttpStatus statusCode) {
if (isEnabled()) {
WIRE_LOGGER.trace("[{}] Received raw response: {}", logId, statusCode);

View File

@ -22,6 +22,7 @@ import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import java.util.function.Supplier;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLContext;
@ -50,12 +51,13 @@ class DefaultClientConfiguration implements ClientConfiguration {
private final @Nullable String pathPrefix;
private final @Nullable HostnameVerifier hostnameVerifier;
private final @Nullable String proxy;
private final @Nullable Function<WebClient, WebClient> webClientConfigurer;
private final Function<WebClient, WebClient> webClientConfigurer;
private final Supplier<HttpHeaders> headersSupplier;
DefaultClientConfiguration(List<InetSocketAddress> hosts, HttpHeaders headers, boolean useSsl,
@Nullable SSLContext sslContext, Duration soTimeout, Duration connectTimeout, @Nullable String pathPrefix,
@Nullable HostnameVerifier hostnameVerifier, @Nullable String proxy,
@Nullable Function<WebClient, WebClient> webClientConfigurer) {
Function<WebClient, WebClient> webClientConfigurer, Supplier<HttpHeaders> headersSupplier) {
this.hosts = Collections.unmodifiableList(new ArrayList<>(hosts));
this.headers = new HttpHeaders(headers);
@ -67,6 +69,7 @@ class DefaultClientConfiguration implements ClientConfiguration {
this.hostnameVerifier = hostnameVerifier;
this.proxy = proxy;
this.webClientConfigurer = webClientConfigurer;
this.headersSupplier = headersSupplier;
}
@Override
@ -117,6 +120,11 @@ class DefaultClientConfiguration implements ClientConfiguration {
@Override
public Function<WebClient, WebClient> getWebClientConfigurer() {
return webClientConfigurer != null ? webClientConfigurer : Function.identity();
return webClientConfigurer;
}
@Override
public Supplier<HttpHeaders> getHeadersSupplier() {
return headersSupplier;
}
}

View File

@ -20,13 +20,11 @@ import java.io.Closeable;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.time.Duration;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLContext;
import org.apache.http.Header;
import org.apache.http.HttpEntity;
import org.apache.http.HttpEntityEnclosingRequest;
@ -87,31 +85,23 @@ public final class RestClients {
HttpHeaders headers = clientConfiguration.getDefaultHeaders();
if (!headers.isEmpty()) {
Header[] httpHeaders = headers.toSingleValueMap().entrySet().stream()
.map(it -> new BasicHeader(it.getKey(), it.getValue())).toArray(Header[]::new);
builder.setDefaultHeaders(httpHeaders);
builder.setDefaultHeaders(toHeaderArray(headers));
}
builder.setHttpClientConfigCallback(clientBuilder -> {
Optional<SSLContext> sslContext = clientConfiguration.getSslContext();
Optional<HostnameVerifier> hostNameVerifier = clientConfiguration.getHostNameVerifier();
sslContext.ifPresent(clientBuilder::setSSLContext);
hostNameVerifier.ifPresent(clientBuilder::setSSLHostnameVerifier);
clientConfiguration.getSslContext().ifPresent(clientBuilder::setSSLContext);
clientConfiguration.getHostNameVerifier().ifPresent(clientBuilder::setSSLHostnameVerifier);
clientBuilder.addInterceptorLast(new CustomHeaderInjector(clientConfiguration.getHeadersSupplier()));
if (ClientLogger.isEnabled()) {
HttpLoggingInterceptor interceptor = new HttpLoggingInterceptor();
clientBuilder.addInterceptorLast((HttpRequestInterceptor) interceptor);
clientBuilder.addInterceptorLast((HttpResponseInterceptor) interceptor);
}
Duration connectTimeout = clientConfiguration.getConnectTimeout();
Duration timeout = clientConfiguration.getSocketTimeout();
Builder requestConfigBuilder = RequestConfig.custom();
Duration connectTimeout = clientConfiguration.getConnectTimeout();
if (!connectTimeout.isNegative()) {
@ -119,6 +109,8 @@ public final class RestClients {
requestConfigBuilder.setConnectionRequestTimeout(Math.toIntExact(connectTimeout.toMillis()));
}
Duration timeout = clientConfiguration.getSocketTimeout();
if (!timeout.isNegative()) {
requestConfigBuilder.setSocketTimeout(Math.toIntExact(timeout.toMillis()));
}
@ -134,8 +126,16 @@ public final class RestClients {
return () -> client;
}
private static Header[] toHeaderArray(HttpHeaders headers) {
return headers.entrySet().stream() //
.flatMap(entry -> entry.getValue().stream() //
.map(value -> new BasicHeader(entry.getKey(), value))) //
.toArray(Header[]::new);
}
private static List<String> formattedHosts(List<InetSocketAddress> hosts, boolean useSsl) {
return hosts.stream().map(it -> (useSsl ? "https" : "http") + "://" + it.getHostString() + ":" + it.getPort()).collect(Collectors.toList());
return hosts.stream().map(it -> (useSsl ? "https" : "http") + "://" + it.getHostString() + ":" + it.getPort())
.collect(Collectors.toList());
}
/**
@ -180,7 +180,6 @@ public final class RestClients {
String logId = (String) context.getAttribute(RestClients.LOG_ID_ATTRIBUTE);
if (logId == null) {
logId = ClientLogger.newLogId();
context.setAttribute(RestClients.LOG_ID_ATTRIBUTE, logId);
}
@ -205,10 +204,31 @@ public final class RestClients {
@Override
public void process(HttpResponse response, HttpContext context) {
String logId = (String) context.getAttribute(RestClients.LOG_ID_ATTRIBUTE);
ClientLogger.logRawResponse(logId, HttpStatus.resolve(response.getStatusLine().getStatusCode()));
}
}
/**
* Interceptor to inject custom supplied headers.
*
* @since 4.0
*/
private static class CustomHeaderInjector implements HttpRequestInterceptor {
public CustomHeaderInjector(Supplier<HttpHeaders> headersSupplier) {
this.headersSupplier = headersSupplier;
}
private final Supplier<HttpHeaders> headersSupplier;
@Override
public void process(HttpRequest request, HttpContext context) {
HttpHeaders httpHeaders = headersSupplier.get();
if (httpHeaders != null && httpHeaders != HttpHeaders.EMPTY) {
Arrays.stream(toHeaderArray(httpHeaders)).forEach(request::addHeader);
}
}
}
}

View File

@ -44,6 +44,7 @@ import java.util.Map.Entry;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.function.Supplier;
import javax.net.ssl.SSLContext;
@ -138,8 +139,8 @@ import org.springframework.web.reactive.function.client.WebClient.RequestBodySpe
public class DefaultReactiveElasticsearchClient implements ReactiveElasticsearchClient, Indices {
private final HostProvider hostProvider;
private final RequestCreator requestCreator;
private Supplier<HttpHeaders> headersSupplier = () -> HttpHeaders.EMPTY;
/**
* Create a new {@link DefaultReactiveElasticsearchClient} using the given {@link HostProvider} to obtain server
@ -167,6 +168,13 @@ public class DefaultReactiveElasticsearchClient implements ReactiveElasticsearch
this.requestCreator = requestCreator;
}
public void setHeadersSupplier(Supplier<HttpHeaders> headersSupplier) {
Assert.notNull(headersSupplier, "headersSupplier must not be null");
this.headersSupplier = headersSupplier;
}
/**
* Create a new {@link DefaultReactiveElasticsearchClient} aware of the given nodes in the cluster. <br />
* <strong>NOTE</strong> If the cluster requires authentication be sure to provide the according {@link HttpHeaders}
@ -216,9 +224,14 @@ public class DefaultReactiveElasticsearchClient implements ReactiveElasticsearch
WebClientProvider provider = getWebClientProvider(clientConfiguration);
HostProvider hostProvider = HostProvider.provider(provider,
HostProvider hostProvider = HostProvider.provider(provider, clientConfiguration.getHeadersSupplier(),
clientConfiguration.getEndpoints().toArray(new InetSocketAddress[0]));
return new DefaultReactiveElasticsearchClient(hostProvider, requestCreator);
DefaultReactiveElasticsearchClient client = new DefaultReactiveElasticsearchClient(hostProvider, requestCreator);
client.setHeadersSupplier(clientConfiguration.getHeadersSupplier());
return client;
}
private static WebClientProvider getWebClientProvider(ClientConfiguration clientConfiguration) {
@ -698,6 +711,12 @@ public class DefaultReactiveElasticsearchClient implements ReactiveElasticsearch
request.getOptions().getHeaders().forEach(it -> theHeaders.add(it.getName(), it.getValue()));
}
}
// plus the ones from the supplier
HttpHeaders suppliedHeaders = headersSupplier.get();
if (suppliedHeaders != null && suppliedHeaders != HttpHeaders.EMPTY) {
theHeaders.addAll(suppliedHeaders);
}
});
if (request.getEntity() != null) {

View File

@ -20,9 +20,11 @@ import reactor.core.publisher.Mono;
import java.net.InetSocketAddress;
import java.util.Collections;
import java.util.Set;
import java.util.function.Supplier;
import org.springframework.data.elasticsearch.client.ElasticsearchHost;
import org.springframework.data.elasticsearch.client.NoReachableHostException;
import org.springframework.http.HttpHeaders;
import org.springframework.util.Assert;
import org.springframework.web.reactive.function.client.WebClient;
@ -40,18 +42,20 @@ public interface HostProvider {
* Create a new {@link HostProvider} best suited for the given {@link WebClientProvider} and number of hosts.
*
* @param clientProvider must not be {@literal null} .
* @param headersSupplier to supply custom headers, must not be {@literal null}
* @param endpoints must not be {@literal null} nor empty.
* @return new instance of {@link HostProvider}.
*/
static HostProvider provider(WebClientProvider clientProvider, InetSocketAddress... endpoints) {
static HostProvider provider(WebClientProvider clientProvider, Supplier<HttpHeaders> headersSupplier,
InetSocketAddress... endpoints) {
Assert.notNull(clientProvider, "WebClientProvider must not be null");
Assert.notEmpty(endpoints, "Please provide at least one endpoint to connect to.");
if (endpoints.length == 1) {
return new SingleNodeHostProvider(clientProvider, endpoints[0]);
return new SingleNodeHostProvider(clientProvider, headersSupplier, endpoints[0]);
} else {
return new MultiNodeHostProvider(clientProvider, endpoints);
return new MultiNodeHostProvider(clientProvider,headersSupplier, endpoints);
}
}

View File

@ -15,6 +15,7 @@
*/
package org.springframework.data.elasticsearch.client.reactive;
import org.springframework.http.HttpHeaders;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.util.function.Tuple2;
@ -27,6 +28,7 @@ import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;
import org.springframework.data.elasticsearch.client.ElasticsearchHost;
import org.springframework.data.elasticsearch.client.ElasticsearchHost.State;
@ -45,11 +47,13 @@ import org.springframework.web.reactive.function.client.WebClient;
class MultiNodeHostProvider implements HostProvider {
private final WebClientProvider clientProvider;
private final Supplier<HttpHeaders> headersSupplier;
private final Map<InetSocketAddress, ElasticsearchHost> hosts;
MultiNodeHostProvider(WebClientProvider clientProvider, InetSocketAddress... endpoints) {
MultiNodeHostProvider(WebClientProvider clientProvider, Supplier<HttpHeaders> headersSupplier, InetSocketAddress... endpoints) {
this.clientProvider = clientProvider;
this.headersSupplier = headersSupplier;
this.hosts = new ConcurrentHashMap<>();
for (InetSocketAddress endpoint : endpoints) {
this.hosts.put(endpoint, new ElasticsearchHost(endpoint, State.UNKNOWN));
@ -133,8 +137,9 @@ class MultiNodeHostProvider implements HostProvider {
.flatMap(host -> {
Mono<ClientResponse> exchange = createWebClient(host) //
.head().uri("/").exchange().doOnError(throwable -> {
.head().uri("/") //
.headers(httpHeaders -> httpHeaders.addAll(headersSupplier.get())) //
.exchange().doOnError(throwable -> {
hosts.put(host, new ElasticsearchHost(host, State.OFFLINE));
clientProvider.getErrorListener().accept(throwable);
});

View File

@ -15,10 +15,12 @@
*/
package org.springframework.data.elasticsearch.client.reactive;
import org.springframework.http.HttpHeaders;
import reactor.core.publisher.Mono;
import java.net.InetSocketAddress;
import java.util.Collections;
import java.util.function.Supplier;
import org.springframework.data.elasticsearch.client.ElasticsearchHost;
import org.springframework.data.elasticsearch.client.ElasticsearchHost.State;
@ -35,12 +37,14 @@ import org.springframework.web.reactive.function.client.WebClient;
class SingleNodeHostProvider implements HostProvider {
private final WebClientProvider clientProvider;
private final Supplier<HttpHeaders> headersSupplier;
private final InetSocketAddress endpoint;
private volatile ElasticsearchHost state;
SingleNodeHostProvider(WebClientProvider clientProvider, InetSocketAddress endpoint) {
SingleNodeHostProvider(WebClientProvider clientProvider, Supplier<HttpHeaders> headersSupplier, InetSocketAddress endpoint) {
this.clientProvider = clientProvider;
this.headersSupplier = headersSupplier;
this.endpoint = endpoint;
this.state = new ElasticsearchHost(this.endpoint, State.UNKNOWN);
}
@ -53,9 +57,10 @@ class SingleNodeHostProvider implements HostProvider {
public Mono<ClusterInformation> clusterInfo() {
return createWebClient(endpoint) //
.head().uri("/").exchange() //
.head().uri("/")
.headers(httpHeaders -> httpHeaders.addAll(headersSupplier.get())) //
.exchange() //
.flatMap(it -> {
if (it.statusCode().isError()) {
state = ElasticsearchHost.offline(endpoint);
} else {

View File

@ -4,44 +4,221 @@ import static com.github.tomakehurst.wiremock.client.WireMock.*;
import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.options;
import java.io.IOException;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.stream.Stream;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.RestHighLevelClient;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.springframework.data.elasticsearch.client.reactive.ReactiveElasticsearchClient;
import org.springframework.data.elasticsearch.client.reactive.ReactiveRestClients;
import org.springframework.http.HttpHeaders;
import com.github.tomakehurst.wiremock.WireMockServer;
import com.github.tomakehurst.wiremock.client.WireMock;
import com.github.tomakehurst.wiremock.matching.AnythingPattern;
import com.github.tomakehurst.wiremock.matching.EqualToPattern;
/**
* @author Peter-Josef Meisch
*/
public class RestClientsTest {
@Test // DATAES-700
void shouldUseConfiguredProxy() throws IOException {
@ParameterizedTest // DATAES-700
@MethodSource("clientUnderTestFactorySource")
@DisplayName("should use configured proxy")
void shouldUseConfiguredProxy(ClientUnderTestFactory clientUnderTestFactory) throws IOException {
WireMockServer wireMockServer = new WireMockServer(options() //
.dynamicPort() //
.usingFilesUnderDirectory("src/test/resources/wiremock-mappings")); // needed, otherwise Wiremock goes to
// test/resources/mappings
wireMockServer.start();
try {
WireMock.configureFor(wireMockServer.port());
if (clientUnderTestFactory instanceof ReactiveElasticsearchClientUnderTestFactory) {
// although the reactive code is using the proxy for every call - tested with an intercepting
// proxy - somehow in this test wiremock fails to register this. So we skip it here
//
return;
}
wireMockServer(server -> {
WireMock.configureFor(server.port());
stubFor(head(urlEqualTo("/")).willReturn(aResponse() //
.withHeader("Content-Type", "application/json; charset=UTF-8")));
ClientConfigurationBuilder configurationBuilder = new ClientConfigurationBuilder();
ClientConfiguration clientConfiguration = configurationBuilder //
.connectedTo("localhost:9200")//
.withProxy("localhost:" + wireMockServer.port()) //
.connectedTo("localhost:4711")//
.withProxy("localhost:" + server.port()) //
.build();
ClientUnderTest clientUnderTest = clientUnderTestFactory.create(clientConfiguration);
RestHighLevelClient restClient = RestClients.create(clientConfiguration).rest();
restClient.ping(RequestOptions.DEFAULT);
clientUnderTest.ping();
verify(headRequestedFor(urlEqualTo("/")));
});
}
@ParameterizedTest // DATAES-801
@MethodSource("clientUnderTestFactorySource")
@DisplayName("should set all required headers")
void shouldSetAllRequiredHeaders(ClientUnderTestFactory clientUnderTestFactory) {
wireMockServer(server -> {
WireMock.configureFor(server.port());
stubFor(head(urlEqualTo("/")).willReturn(aResponse() //
.withHeader("Content-Type", "application/json; charset=UTF-8")));
HttpHeaders defaultHeaders = new HttpHeaders();
defaultHeaders.addAll("def1", Arrays.asList("def1-1", "def1-2"));
defaultHeaders.add("def2", "def2-1");
AtomicInteger supplierCount = new AtomicInteger(1);
ClientConfigurationBuilder configurationBuilder = new ClientConfigurationBuilder();
ClientConfiguration clientConfiguration = configurationBuilder //
.connectedTo("localhost:" + server.port()) //
.withBasicAuth("user", "password") //
.withDefaultHeaders(defaultHeaders) //
.withHeaders(() -> {
HttpHeaders httpHeaders = new HttpHeaders();
httpHeaders.add("supplied", "val0");
httpHeaders.add("supplied", "val" + supplierCount.getAndIncrement());
return httpHeaders;
}).build();
ClientUnderTest clientUnderTest = clientUnderTestFactory.create(clientConfiguration);
// do several calls to check that the headerSupplier provided values are set
for (int i = 1; i <= 3; i++) {
clientUnderTest.ping();
verify(headRequestedFor(urlEqualTo("/")).withHeader("Authorization", new AnythingPattern()) //
.withHeader("def1", new EqualToPattern("def1-1")) //
.withHeader("def1", new EqualToPattern("def1-2")) //
.withHeader("def2", new EqualToPattern("def2-1")) //
.withHeader("supplied", new EqualToPattern("val0")) //
.withHeader("supplied", new EqualToPattern("val" + i)) //
);
}
});
}
/**
* Consumer extension that catches checked exceptions and wraps them in a RuntimeException.
*/
@FunctionalInterface
interface WiremockConsumer extends Consumer<WireMockServer> {
@Override
default void accept(WireMockServer wiremockConsumer) {
try {
acceptThrows(wiremockConsumer);
} catch (final Exception e) {
throw new RuntimeException(e);
}
}
void acceptThrows(WireMockServer wiremockConsumer) throws Exception;
}
/**
* starts a Wiremock server and calls consumer with the server as argument. Stops the server after consumer execution.
*
* @param consumer the consumer
*/
private void wireMockServer(WiremockConsumer consumer) {
WireMockServer wireMockServer = new WireMockServer(options() //
.dynamicPort() //
.usingFilesUnderDirectory("src/test/resources/wiremock-mappings")); // needed, otherwise Wiremock goes to
// test/resources/mappings
try {
wireMockServer.start();
consumer.accept(wireMockServer);
} finally {
wireMockServer.shutdown();
}
}
/**
* The client to be tested. Abstraction to be able to test reactive and non-reactive clients.
*/
interface ClientUnderTest {
/**
* Pings the configured server.
*
* @return
*/
boolean ping() throws Exception;
}
/**
* base class to create {@link ClientUnderTest} implementations.
*/
static abstract class ClientUnderTestFactory {
abstract ClientUnderTest create(ClientConfiguration clientConfiguration);
@Override
public String toString() {
return getDisplayName();
}
protected abstract String getDisplayName();
}
/**
* {@link ClientUnderTestFactory} implementation for the Standard {@link RestHighLevelClient}.
*/
static class RestClientUnderTestFactory extends ClientUnderTestFactory {
@Override
protected String getDisplayName() {
return "RestHighLevelClient";
}
@Override
ClientUnderTest create(ClientConfiguration clientConfiguration) {
RestHighLevelClient client = RestClients.create(clientConfiguration).rest();
return new ClientUnderTest() {
@Override
public boolean ping() throws Exception {
return client.ping(RequestOptions.DEFAULT);
}
};
}
}
/**
* {@link ClientUnderTestFactory} implementation for the {@link ReactiveElasticsearchClient}.
*/
static class ReactiveElasticsearchClientUnderTestFactory extends ClientUnderTestFactory {
@Override
protected String getDisplayName() {
return "ReactiveElasticsearchClient";
}
@Override
ClientUnderTest create(ClientConfiguration clientConfiguration) {
ReactiveElasticsearchClient client = ReactiveRestClients.create(clientConfiguration);
return new ClientUnderTest() {
@Override
public boolean ping() throws Exception {
return client.ping().block();
}
};
}
}
/**
* Provides the factories to use in the parameterized tests
*
* @return stream of factories
*/
static Stream<ClientUnderTestFactory> clientUnderTestFactorySource() {
return Stream.of(new RestClientUnderTestFactory(), new ReactiveElasticsearchClientUnderTestFactory());
}
}

View File

@ -83,10 +83,10 @@ public class ReactiveMockClientTestsUtils {
if (hosts.length == 1) {
delegate = new SingleNodeHostProvider(clientProvider, getInetSocketAddress(hosts[0])) {};
delegate = new SingleNodeHostProvider(clientProvider, HttpHeaders::new, getInetSocketAddress(hosts[0])) {};
} else {
delegate = new MultiNodeHostProvider(clientProvider, Arrays.stream(hosts)
delegate = new MultiNodeHostProvider(clientProvider,HttpHeaders::new, Arrays.stream(hosts)
.map(ReactiveMockClientTestsUtils::getInetSocketAddress).toArray(InetSocketAddress[]::new)) {};
}