Enable customizing headers in token requests

Adds the possibility to customize the headers of the access token request in AbstractWebClientReactiveOAuth2AccessTokenResponseClient, similarly to what is done in the AbstractOAuth2AuthorizationGrantRequestEntityConverter.

Closes gh-10130
This commit is contained in:
Vincent Boulaye 2021-07-20 23:05:17 +02:00 committed by Steve Riesenberg
parent 1806cebd64
commit 044157061f
5 changed files with 326 additions and 6 deletions

View File

@ -24,6 +24,7 @@ import java.util.Set;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistration;
@ -65,6 +66,8 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
private WebClient webClient = WebClient.builder().build(); private WebClient webClient = WebClient.builder().build();
private Converter<T, HttpHeaders> headersConverter = this::populateTokenRequestHeaders;
AbstractWebClientReactiveOAuth2AccessTokenResponseClient() { AbstractWebClientReactiveOAuth2AccessTokenResponseClient() {
} }
@ -74,7 +77,12 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
// @formatter:off // @formatter:off
return Mono.defer(() -> this.webClient.post() return Mono.defer(() -> this.webClient.post()
.uri(clientRegistration(grantRequest).getProviderDetails().getTokenUri()) .uri(clientRegistration(grantRequest).getProviderDetails().getTokenUri())
.headers((headers) -> populateTokenRequestHeaders(grantRequest, headers)) .headers((headers) -> {
HttpHeaders headersToAdd = getHeadersConverter().convert(grantRequest);
if (headersToAdd != null) {
headers.addAll(headersToAdd);
}
})
.body(createTokenRequestBody(grantRequest)) .body(createTokenRequestBody(grantRequest))
.exchange() .exchange()
.flatMap((response) -> readTokenResponse(grantRequest, response)) .flatMap((response) -> readTokenResponse(grantRequest, response))
@ -92,9 +100,10 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
/** /**
* Populates the headers for the token request. * Populates the headers for the token request.
* @param grantRequest the grant request * @param grantRequest the grant request
* @param headers the headers to populate * @return the headers populated for the token request
*/ */
private void populateTokenRequestHeaders(T grantRequest, HttpHeaders headers) { private HttpHeaders populateTokenRequestHeaders(T grantRequest) {
HttpHeaders headers = new HttpHeaders();
ClientRegistration clientRegistration = clientRegistration(grantRequest); ClientRegistration clientRegistration = clientRegistration(grantRequest);
headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED); headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED);
headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON)); headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON));
@ -104,6 +113,7 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
String clientSecret = encodeClientCredential(clientRegistration.getClientSecret()); String clientSecret = encodeClientCredential(clientRegistration.getClientSecret());
headers.setBasicAuth(clientId, clientSecret); headers.setBasicAuth(clientId, clientSecret);
} }
return headers;
} }
private static String encodeClientCredential(String clientCredential) { private static String encodeClientCredential(String clientCredential) {
@ -230,4 +240,55 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
this.webClient = webClient; this.webClient = webClient;
} }
/**
* Returns the {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link HttpHeaders}
* used in the OAuth 2.0 Access Token Request headers.
* @return the {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} to {@link HttpHeaders}
*/
final Converter<T, HttpHeaders> getHeadersConverter() {
return this.headersConverter;
}
/**
* Sets the {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link HttpHeaders}
* used in the OAuth 2.0 Access Token Request headers.
* @param headersConverter the {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} to {@link HttpHeaders}
* @since 5.6
*/
public final void setHeadersConverter(Converter<T, HttpHeaders> headersConverter) {
Assert.notNull(headersConverter, "headersConverter cannot be null");
this.headersConverter = headersConverter;
}
/**
* Add (compose) the provided {@code headersConverter} to the current
* {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link HttpHeaders}
* used in the OAuth 2.0 Access Token Request headers.
* @param headersConverter the {@link Converter} to add (compose) to the current
* {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} to a {@link HttpHeaders}
* @since 5.6
*/
public final void addHeadersConverter(Converter<T, HttpHeaders> headersConverter) {
Assert.notNull(headersConverter, "headersConverter cannot be null");
Converter<T, HttpHeaders> currentHeadersConverter = this.headersConverter;
this.headersConverter = (authorizationGrantRequest) -> {
// Append headers using a Composite Converter
HttpHeaders headers = currentHeadersConverter.convert(authorizationGrantRequest);
if (headers == null) {
headers = new HttpHeaders();
}
HttpHeaders headersToAdd = headersConverter.convert(authorizationGrantRequest);
if (headersToAdd != null) {
headers.addAll(headersToAdd);
}
return headers;
};
}
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -17,15 +17,18 @@
package org.springframework.security.oauth2.client.endpoint; package org.springframework.security.oauth2.client.endpoint;
import java.time.Instant; import java.time.Instant;
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
@ -340,4 +343,65 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
return new OAuth2AuthorizationCodeGrantRequest(registration, authorizationExchange); return new OAuth2AuthorizationCodeGrantRequest(registration, authorizationExchange);
} }
@Test
public void setHeadersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.setHeadersConverter(null))
.withMessage("headersConverter cannot be null");
}
@Test
public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.addHeadersConverter(null))
.withMessage("headersConverter cannot be null");
}
@Test
public void convertWhenHeadersConverterAddedThenCalled() throws Exception {
OAuth2AuthorizationCodeGrantRequest request = authorizationCodeGrantRequest();
Converter<OAuth2AuthorizationCodeGrantRequest, HttpHeaders> addedHeadersConverter = mock(Converter.class);
final HttpHeaders headers = new HttpHeaders();
headers.put("CUSTOM_AUTHORIZATION", Collections.singletonList("Basic CUSTOM"));
given(addedHeadersConverter.convert(request)).willReturn(headers);
this.tokenResponseClient.addHeadersConverter(addedHeadersConverter);
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"openid profile\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.tokenResponseClient.getTokenResponse(request).block();
verify(addedHeadersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION))
.isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
assertThat(actualRequest.getHeader("CUSTOM_AUTHORIZATION")).isEqualTo("Basic CUSTOM");
}
@Test
public void convertWhenHeadersConverterSetThenCalled() throws Exception {
OAuth2AuthorizationCodeGrantRequest request = authorizationCodeGrantRequest();
Converter<OAuth2AuthorizationCodeGrantRequest, HttpHeaders> headersConverter = mock(Converter.class);
final HttpHeaders headers = new HttpHeaders();
headers.put(HttpHeaders.AUTHORIZATION, Collections.singletonList("Basic CUSTOM"));
given(headersConverter.convert(request)).willReturn(headers);
this.tokenResponseClient.setHeadersConverter(headersConverter);
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"openid profile\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.tokenResponseClient.getTokenResponse(request).block();
verify(headersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic CUSTOM");
}
} }

View File

@ -19,6 +19,7 @@ package org.springframework.security.oauth2.client.endpoint;
import java.net.URLEncoder; import java.net.URLEncoder;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.Base64; import java.util.Base64;
import java.util.Collections;
import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.MockWebServer;
@ -27,6 +28,7 @@ import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistration;
@ -212,4 +214,64 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests {
this.server.enqueue(response); this.server.enqueue(response);
} }
@Test
public void setHeadersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.client.setHeadersConverter(null))
.withMessage("headersConverter cannot be null");
}
@Test
public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.client.addHeadersConverter(null))
.withMessage("headersConverter cannot be null");
}
@Test
public void convertWhenHeadersConverterAddedThenCalled() throws Exception {
OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(
this.clientRegistration.build());
Converter<OAuth2ClientCredentialsGrantRequest, HttpHeaders> addedHeadersConverter = mock(Converter.class);
final HttpHeaders headers = new HttpHeaders();
headers.put("CUSTOM_AUTHORIZATION", Collections.singletonList("Basic CUSTOM"));
given(addedHeadersConverter.convert(request)).willReturn(headers);
this.client.addHeadersConverter(addedHeadersConverter);
// @formatter:off
enqueueJson("{\n"
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
+ "}");
// @formatter:on
this.client.getTokenResponse(request).block();
verify(addedHeadersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION))
.isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
assertThat(actualRequest.getHeader("CUSTOM_AUTHORIZATION")).isEqualTo("Basic CUSTOM");
}
@Test
public void convertWhenHeadersConverterSetThenCalled() throws Exception {
OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(
this.clientRegistration.build());
Converter<OAuth2ClientCredentialsGrantRequest, HttpHeaders> headersConverter = mock(Converter.class);
final HttpHeaders headers = new HttpHeaders();
headers.put(HttpHeaders.AUTHORIZATION, Collections.singletonList("Basic CUSTOM"));
given(headersConverter.convert(request)).willReturn(headers);
this.client.setHeadersConverter(headersConverter);
// @formatter:off
enqueueJson("{\n"
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
+ "}");
// @formatter:on
this.client.getTokenResponse(request).block();
verify(headersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic CUSTOM");
}
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -17,6 +17,7 @@
package org.springframework.security.oauth2.client.endpoint; package org.springframework.security.oauth2.client.endpoint;
import java.time.Instant; import java.time.Instant;
import java.util.Collections;
import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.MockWebServer;
@ -25,6 +26,7 @@ import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod; import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
@ -38,6 +40,9 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenRespon
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
/** /**
* Tests for {@link WebClientReactivePasswordTokenResponseClient}. * Tests for {@link WebClientReactivePasswordTokenResponseClient}.
@ -213,4 +218,66 @@ public class WebClientReactivePasswordTokenResponseClientTests {
// @formatter:on // @formatter:on
} }
@Test
public void setHeadersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.setHeadersConverter(null))
.withMessage("headersConverter cannot be null");
}
@Test
public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.addHeadersConverter(null))
.withMessage("headersConverter cannot be null");
}
@Test
public void convertWhenHeadersConverterAddedThenCalled() throws Exception {
OAuth2PasswordGrantRequest request = new OAuth2PasswordGrantRequest(this.clientRegistrationBuilder.build(),
this.username, this.password);
Converter<OAuth2PasswordGrantRequest, HttpHeaders> addedHeadersConverter = mock(Converter.class);
final HttpHeaders headers = new HttpHeaders();
headers.put("CUSTOM_AUTHORIZATION", Collections.singletonList("Basic CUSTOM"));
given(addedHeadersConverter.convert(request)).willReturn(headers);
this.tokenResponseClient.addHeadersConverter(addedHeadersConverter);
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.tokenResponseClient.getTokenResponse(request).block();
verify(addedHeadersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION))
.isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
assertThat(actualRequest.getHeader("CUSTOM_AUTHORIZATION")).isEqualTo("Basic CUSTOM");
}
@Test
public void convertWhenHeadersConverterSetThenCalled() throws Exception {
OAuth2PasswordGrantRequest request = new OAuth2PasswordGrantRequest(this.clientRegistrationBuilder.build(),
this.username, this.password);
Converter<OAuth2PasswordGrantRequest, HttpHeaders> headersConverter = mock(Converter.class);
final HttpHeaders headers = new HttpHeaders();
headers.put(HttpHeaders.AUTHORIZATION, Collections.singletonList("Basic CUSTOM"));
given(headersConverter.convert(request)).willReturn(headers);
this.tokenResponseClient.setHeadersConverter(headersConverter);
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.tokenResponseClient.getTokenResponse(request).block();
verify(headersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic CUSTOM");
}
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -26,6 +26,7 @@ import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod; import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
@ -42,6 +43,9 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenRespon
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
/** /**
* Tests for {@link WebClientReactiveRefreshTokenTokenResponseClient}. * Tests for {@link WebClientReactiveRefreshTokenTokenResponseClient}.
@ -217,4 +221,66 @@ public class WebClientReactiveRefreshTokenTokenResponseClientTests {
// @formatter:on // @formatter:on
} }
@Test
public void setHeadersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.setHeadersConverter(null))
.withMessage("headersConverter cannot be null");
}
@Test
public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.addHeadersConverter(null))
.withMessage("headersConverter cannot be null");
}
@Test
public void convertWhenHeadersConverterAddedThenCalled() throws Exception {
OAuth2RefreshTokenGrantRequest request = new OAuth2RefreshTokenGrantRequest(
this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken);
Converter<OAuth2RefreshTokenGrantRequest, HttpHeaders> addedHeadersConverter = mock(Converter.class);
final HttpHeaders headers = new HttpHeaders();
headers.put("CUSTOM_AUTHORIZATION", Collections.singletonList("Basic CUSTOM"));
given(addedHeadersConverter.convert(request)).willReturn(headers);
this.tokenResponseClient.addHeadersConverter(addedHeadersConverter);
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.tokenResponseClient.getTokenResponse(request).block();
verify(addedHeadersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION))
.isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
assertThat(actualRequest.getHeader("CUSTOM_AUTHORIZATION")).isEqualTo("Basic CUSTOM");
}
@Test
public void convertWhenHeadersConverterSetThenCalled() throws Exception {
OAuth2RefreshTokenGrantRequest request = new OAuth2RefreshTokenGrantRequest(
this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken);
Converter<OAuth2RefreshTokenGrantRequest, HttpHeaders> headersConverter1 = mock(Converter.class);
final HttpHeaders headers = new HttpHeaders();
headers.put(HttpHeaders.AUTHORIZATION, Collections.singletonList("Basic CUSTOM"));
given(headersConverter1.convert(request)).willReturn(headers);
this.tokenResponseClient.setHeadersConverter(headersConverter1);
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.tokenResponseClient.getTokenResponse(request).block();
verify(headersConverter1).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic CUSTOM");
}
} }