Add parameters converter support to AbstractWebClientReactiveOAuth2AccessTokenResponseClient

This adds support for configuring NimbusJwtClientAuthenticationParametersConverter to any AbstractWebClientReactiveOAuth2AccessTokenResponseClient as an additional parameters converter, which in turns adds reactive support for jwt client authentication.

Closes gh-10146
This commit is contained in:
Steve Riesenberg 2021-09-28 16:17:36 -05:00 committed by Steve Riesenberg
parent f561499683
commit 3b564b2026
6 changed files with 690 additions and 5 deletions

View File

@ -35,6 +35,8 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import org.springframework.web.reactive.function.BodyExtractor;
import org.springframework.web.reactive.function.BodyInserters;
@ -70,6 +72,8 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
private Converter<T, HttpHeaders> headersConverter = this::populateTokenRequestHeaders;
private Converter<T, MultiValueMap<String, String>> parametersConverter = this::populateTokenRequestParameters;
private BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> bodyExtractor = OAuth2BodyExtractors
.oauth2AccessTokenResponse();
@ -132,7 +136,19 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
}
/**
* Creates and returns the body for the token request.
* Populates default parameters for the token request.
* @param grantRequest the grant request
* @return the parameters populated for the token request.
*/
private MultiValueMap<String, String> populateTokenRequestParameters(T grantRequest) {
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add(OAuth2ParameterNames.GRANT_TYPE, grantRequest.getGrantType().getValue());
return parameters;
}
/**
* Combine the results of {@code parametersConverter} and
* {@link #populateTokenRequestBody}.
*
* <p>
* This method pre-populates the body with some standard properties, and then
@ -144,9 +160,8 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
* @return the body for the token request.
*/
private BodyInserters.FormInserter<String> createTokenRequestBody(T grantRequest) {
BodyInserters.FormInserter<String> body = BodyInserters.fromFormData(OAuth2ParameterNames.GRANT_TYPE,
grantRequest.getGrantType().getValue());
return populateTokenRequestBody(grantRequest, body);
MultiValueMap<String, String> parameters = getParametersConverter().convert(grantRequest);
return populateTokenRequestBody(grantRequest, BodyInserters.fromFormData(parameters));
}
/**
@ -296,6 +311,56 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
};
}
/**
* Returns the {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link MultiValueMap}
* used in the OAuth 2.0 Access Token Request body.
* @return the {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} to {@link MultiValueMap}
*/
final Converter<T, MultiValueMap<String, String>> getParametersConverter() {
return this.parametersConverter;
}
/**
* Sets the {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link MultiValueMap}
* used in the OAuth 2.0 Access Token Request body.
* @param parametersConverter the {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} to {@link MultiValueMap}
* @since 5.6
*/
public final void setParametersConverter(Converter<T, MultiValueMap<String, String>> parametersConverter) {
Assert.notNull(parametersConverter, "parametersConverter cannot be null");
this.parametersConverter = parametersConverter;
}
/**
* Add (compose) the provided {@code parametersConverter} to the current
* {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link MultiValueMap}
* used in the OAuth 2.0 Access Token Request body.
* @param parametersConverter the {@link Converter} to add (compose) to the current
* {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} to a {@link MultiValueMap}
* @since 5.6
*/
public final void addParametersConverter(Converter<T, MultiValueMap<String, String>> parametersConverter) {
Assert.notNull(parametersConverter, "parametersConverter cannot be null");
Converter<T, MultiValueMap<String, String>> currentParametersConverter = this.parametersConverter;
this.parametersConverter = (authorizationGrantRequest) -> {
MultiValueMap<String, String> parameters = currentParametersConverter.convert(authorizationGrantRequest);
if (parameters == null) {
parameters = new LinkedMultiValueMap<>();
}
MultiValueMap<String, String> parametersToAdd = parametersConverter.convert(authorizationGrantRequest);
if (parametersToAdd != null) {
parameters.addAll(parametersToAdd);
}
return parameters;
};
}
/**
* Sets the {@link BodyExtractor} that will be used to decode the
* {@link OAuth2AccessTokenResponse}

View File

@ -16,11 +16,16 @@
package org.springframework.security.oauth2.client.endpoint;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
import javax.crypto.spec.SecretKeySpec;
import com.nimbusds.jose.jwk.JWK;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
@ -36,6 +41,7 @@ import org.springframework.http.MediaType;
import org.springframework.http.ReactiveHttpInputMessage;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
@ -44,6 +50,10 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
import org.springframework.security.oauth2.jose.TestJwks;
import org.springframework.security.oauth2.jose.TestKeys;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.reactive.function.BodyExtractor;
import org.springframework.web.reactive.function.client.WebClient;
@ -112,6 +122,75 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
assertThat(accessTokenResponse.getAdditionalParameters()).containsEntry("custom_parameter_2", "custom-value-2");
}
@Test
public void getTokenResponseWhenAuthenticationClientSecretJwtThenFormParametersAreSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
// @formatter:off
ClientRegistration clientRegistration = this.clientRegistration
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT)
.clientSecret(TestKeys.DEFAULT_ENCODED_SECRET_KEY)
.build();
// @formatter:on
// Configure Jwt client authentication converter
SecretKeySpec secretKey = new SecretKeySpec(
clientRegistration.getClientSecret().getBytes(StandardCharsets.UTF_8), "HmacSHA256");
JWK jwk = TestJwks.jwk(secretKey).build();
Function<ClientRegistration, JWK> jwkResolver = (registration) -> jwk;
configureJwtClientAuthenticationConverter(jwkResolver);
this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest(clientRegistration)).block();
RecordedRequest actualRequest = this.server.takeRequest();
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull();
assertThat(actualRequest.getBody().readUtf8()).contains("grant_type=authorization_code",
"client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer",
"client_assertion=");
}
@Test
public void getTokenResponseWhenAuthenticationPrivateKeyJwtThenFormParametersAreSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
// @formatter:off
ClientRegistration clientRegistration = this.clientRegistration
.clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT)
.build();
// @formatter:on
// Configure Jwt client authentication converter
JWK jwk = TestJwks.DEFAULT_RSA_JWK;
Function<ClientRegistration, JWK> jwkResolver = (registration) -> jwk;
configureJwtClientAuthenticationConverter(jwkResolver);
this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest(clientRegistration)).block();
RecordedRequest actualRequest = this.server.takeRequest();
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull();
assertThat(actualRequest.getBody().readUtf8()).contains("grant_type=authorization_code",
"client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer",
"client_assertion=");
}
private void configureJwtClientAuthenticationConverter(Function<ClientRegistration, JWK> jwkResolver) {
NimbusJwtClientAuthenticationParametersConverter<OAuth2AuthorizationCodeGrantRequest> jwtClientAuthenticationConverter = new NimbusJwtClientAuthenticationParametersConverter<>(
jwkResolver);
this.tokenResponseClient.addParametersConverter(jwtClientAuthenticationConverter);
}
// @Test
// public void
// getTokenResponseWhenRedirectUriMalformedThenThrowIllegalArgumentException() throws
@ -261,7 +340,10 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
}
private OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest() {
ClientRegistration registration = this.clientRegistration.build();
return authorizationCodeGrantRequest(this.clientRegistration.build());
}
private OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest(ClientRegistration registration) {
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.clientId(registration.getClientId()).state("state")
.authorizationUri(registration.getProviderDetails().getAuthorizationUri())
@ -414,6 +496,67 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
.isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
}
@Test
public void setParametersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.setParametersConverter(null))
.withMessage("parametersConverter cannot be null");
}
@Test
public void addParametersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.addParametersConverter(null))
.withMessage("parametersConverter cannot be null");
}
@Test
public void convertWhenParametersConverterAddedThenCalled() throws Exception {
OAuth2AuthorizationCodeGrantRequest request = authorizationCodeGrantRequest();
Converter<OAuth2AuthorizationCodeGrantRequest, MultiValueMap<String, String>> addedParametersConverter = mock(
Converter.class);
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add("custom-parameter-name", "custom-parameter-value");
given(addedParametersConverter.convert(request)).willReturn(parameters);
this.tokenResponseClient.addParametersConverter(addedParametersConverter);
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
+ "}";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.tokenResponseClient.getTokenResponse(request).block();
verify(addedParametersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
assertThat(actualRequest.getBody().readUtf8()).contains("grant_type=authorization_code",
"custom-parameter-name=custom-parameter-value");
}
@Test
public void convertWhenParametersConverterSetThenCalled() throws Exception {
OAuth2AuthorizationCodeGrantRequest request = authorizationCodeGrantRequest();
Converter<OAuth2AuthorizationCodeGrantRequest, MultiValueMap<String, String>> parametersConverter = mock(
Converter.class);
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add("custom-parameter-name", "custom-parameter-value");
given(parametersConverter.convert(request)).willReturn(parameters);
this.tokenResponseClient.setParametersConverter(parametersConverter);
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
+ "}";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.tokenResponseClient.getTokenResponse(request).block();
verify(parametersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
assertThat(actualRequest.getBody().readUtf8()).contains("custom-parameter-name=custom-parameter-value");
}
// gh-10260
@Test
public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse() {

View File

@ -20,7 +20,11 @@ import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Collections;
import java.util.function.Function;
import javax.crypto.spec.SecretKeySpec;
import com.nimbusds.jose.jwk.JWK;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
@ -39,6 +43,10 @@ import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
import org.springframework.security.oauth2.jose.TestJwks;
import org.springframework.security.oauth2.jose.TestKeys;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.reactive.function.BodyExtractor;
import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.reactive.function.client.WebClientResponseException;
@ -152,6 +160,75 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests {
"grant_type=client_credentials&client_id=client-id&client_secret=client-secret&scope=read%3Auser");
}
@Test
public void getTokenResponseWhenAuthenticationClientSecretJwtThenFormParametersAreSent() throws Exception {
// @formatter:off
enqueueJson("{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}");
// @formatter:on
// @formatter:off
ClientRegistration clientRegistration = this.clientRegistration
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT)
.clientSecret(TestKeys.DEFAULT_ENCODED_SECRET_KEY)
.build();
// @formatter:on
// Configure Jwt client authentication converter
SecretKeySpec secretKey = new SecretKeySpec(
clientRegistration.getClientSecret().getBytes(StandardCharsets.UTF_8), "HmacSHA256");
JWK jwk = TestJwks.jwk(secretKey).build();
Function<ClientRegistration, JWK> jwkResolver = (registration) -> jwk;
configureJwtClientAuthenticationConverter(jwkResolver);
OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
this.client.getTokenResponse(request).block();
RecordedRequest actualRequest = this.server.takeRequest();
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull();
assertThat(actualRequest.getBody().readUtf8()).contains("grant_type=client_credentials",
"client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer",
"client_assertion=");
}
@Test
public void getTokenResponseWhenAuthenticationPrivateKeyJwtThenFormParametersAreSent() throws Exception {
// @formatter:off
enqueueJson("{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}");
// @formatter:on
// @formatter:off
ClientRegistration clientRegistration = this.clientRegistration
.clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT)
.build();
// @formatter:on
// Configure Jwt client authentication converter
JWK jwk = TestJwks.DEFAULT_RSA_JWK;
Function<ClientRegistration, JWK> jwkResolver = (registration) -> jwk;
configureJwtClientAuthenticationConverter(jwkResolver);
OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
this.client.getTokenResponse(request).block();
RecordedRequest actualRequest = this.server.takeRequest();
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull();
assertThat(actualRequest.getBody().readUtf8()).contains("grant_type=client_credentials",
"client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer",
"client_assertion=");
}
private void configureJwtClientAuthenticationConverter(Function<ClientRegistration, JWK> jwkResolver) {
NimbusJwtClientAuthenticationParametersConverter<OAuth2ClientCredentialsGrantRequest> jwtClientAuthenticationConverter = new NimbusJwtClientAuthenticationParametersConverter<>(
jwkResolver);
this.client.addParametersConverter(jwtClientAuthenticationConverter);
}
@Test
public void getTokenResponseWhenNoScopeThenClientRegistrationScopesDefaulted() {
ClientRegistration registration = this.clientRegistration.build();
@ -285,6 +362,67 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests {
.isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
}
@Test
public void setParametersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.client.setParametersConverter(null))
.withMessage("parametersConverter cannot be null");
}
@Test
public void addParametersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.client.addParametersConverter(null))
.withMessage("parametersConverter cannot be null");
}
@Test
public void convertWhenParametersConverterAddedThenCalled() throws Exception {
OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(
this.clientRegistration.build());
Converter<OAuth2ClientCredentialsGrantRequest, MultiValueMap<String, String>> addedParametersConverter = mock(
Converter.class);
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add("custom-parameter-name", "custom-parameter-value");
given(addedParametersConverter.convert(request)).willReturn(parameters);
this.client.addParametersConverter(addedParametersConverter);
// @formatter:off
enqueueJson("{\n"
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
+ "}");
// @formatter:on
this.client.getTokenResponse(request).block();
verify(addedParametersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
assertThat(actualRequest.getBody().readUtf8()).contains("grant_type=client_credentials",
"custom-parameter-name=custom-parameter-value");
}
@Test
public void convertWhenParametersConverterSetThenCalled() throws Exception {
OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(
this.clientRegistration.build());
Converter<OAuth2ClientCredentialsGrantRequest, MultiValueMap<String, String>> parametersConverter = mock(
Converter.class);
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add("custom-parameter-name", "custom-parameter-value");
given(parametersConverter.convert(request)).willReturn(parameters);
this.client.setParametersConverter(parametersConverter);
// @formatter:off
enqueueJson("{\n"
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
+ "}");
// @formatter:on
this.client.getTokenResponse(request).block();
verify(parametersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
assertThat(actualRequest.getBody().readUtf8()).contains("custom-parameter-name=custom-parameter-value");
}
// gh-10260
@Test
public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse() {

View File

@ -40,6 +40,8 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenRespon
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.TestJwts;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.reactive.function.BodyExtractor;
import org.springframework.web.reactive.function.client.WebClient;
@ -228,6 +230,53 @@ public class WebClientReactiveJwtBearerTokenResponseClientTests {
assertThat(actualRequest.getHeader("custom-header-name")).isEqualTo("custom-header-value");
}
@Test
public void setParametersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.client.setParametersConverter(null))
.withMessage("parametersConverter cannot be null");
}
@Test
public void addParametersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.client.addParametersConverter(null))
.withMessage("parametersConverter cannot be null");
}
@Test
public void convertWhenParametersConverterAddedThenCalled() throws Exception {
ClientRegistration clientRegistration = this.clientRegistration.build();
JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion);
Converter<JwtBearerGrantRequest, MultiValueMap<String, String>> addedParametersConverter = mock(
Converter.class);
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add("custom-parameter-name", "custom-parameter-value");
given(addedParametersConverter.convert(request)).willReturn(parameters);
this.client.addParametersConverter(addedParametersConverter);
enqueueJson(DEFAULT_ACCESS_TOKEN_RESPONSE);
this.client.getTokenResponse(request).block();
verify(addedParametersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
assertThat(actualRequest.getBody().readUtf8()).contains(
"grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer",
"custom-parameter-name=custom-parameter-value");
}
@Test
public void convertWhenParametersConverterSetThenCalled() throws Exception {
ClientRegistration clientRegistration = this.clientRegistration.build();
JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion);
Converter<JwtBearerGrantRequest, MultiValueMap<String, String>> parametersConverter = mock(Converter.class);
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add("custom-parameter-name", "custom-parameter-value");
given(parametersConverter.convert(request)).willReturn(parameters);
this.client.setParametersConverter(parametersConverter);
enqueueJson(DEFAULT_ACCESS_TOKEN_RESPONSE);
this.client.getTokenResponse(request).block();
verify(parametersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
assertThat(actualRequest.getBody().readUtf8()).contains("custom-parameter-name=custom-parameter-value");
}
@Test
public void getTokenResponseWhenBodyExtractorSetThenCalled() {
BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> bodyExtractor = mock(

View File

@ -16,9 +16,14 @@
package org.springframework.security.oauth2.client.endpoint;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.Collections;
import java.util.function.Function;
import javax.crypto.spec.SecretKeySpec;
import com.nimbusds.jose.jwk.JWK;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
@ -39,6 +44,10 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
import org.springframework.security.oauth2.jose.TestJwks;
import org.springframework.security.oauth2.jose.TestKeys;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.reactive.function.BodyExtractor;
import static org.assertj.core.api.Assertions.assertThat;
@ -146,6 +155,79 @@ public class WebClientReactivePasswordTokenResponseClientTests {
assertThat(formParameters).contains("client_secret=client-secret");
}
@Test
public void getTokenResponseWhenAuthenticationClientSecretJwtThenFormParametersAreSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
// @formatter:off
ClientRegistration clientRegistration = this.clientRegistrationBuilder
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT)
.clientSecret(TestKeys.DEFAULT_ENCODED_SECRET_KEY)
.build();
// @formatter:on
// Configure Jwt client authentication converter
SecretKeySpec secretKey = new SecretKeySpec(
clientRegistration.getClientSecret().getBytes(StandardCharsets.UTF_8), "HmacSHA256");
JWK jwk = TestJwks.jwk(secretKey).build();
Function<ClientRegistration, JWK> jwkResolver = (registration) -> jwk;
configureJwtClientAuthenticationConverter(jwkResolver);
OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration,
this.username, this.password);
this.tokenResponseClient.getTokenResponse(passwordGrantRequest).block();
RecordedRequest actualRequest = this.server.takeRequest();
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull();
assertThat(actualRequest.getBody().readUtf8()).contains("grant_type=password",
"client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer",
"client_assertion=");
}
@Test
public void getTokenResponseWhenAuthenticationPrivateKeyJwtThenFormParametersAreSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
// @formatter:off
ClientRegistration clientRegistration = this.clientRegistrationBuilder
.clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT)
.build();
// @formatter:on
// Configure Jwt client authentication converter
JWK jwk = TestJwks.DEFAULT_RSA_JWK;
Function<ClientRegistration, JWK> jwkResolver = (registration) -> jwk;
configureJwtClientAuthenticationConverter(jwkResolver);
OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration,
this.username, this.password);
this.tokenResponseClient.getTokenResponse(passwordGrantRequest).block();
RecordedRequest actualRequest = this.server.takeRequest();
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull();
assertThat(actualRequest.getBody().readUtf8()).contains("grant_type=password",
"client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer",
"client_assertion=");
}
private void configureJwtClientAuthenticationConverter(Function<ClientRegistration, JWK> jwkResolver) {
NimbusJwtClientAuthenticationParametersConverter<OAuth2PasswordGrantRequest> jwtClientAuthenticationConverter = new NimbusJwtClientAuthenticationParametersConverter<>(
jwkResolver);
this.tokenResponseClient.addParametersConverter(jwtClientAuthenticationConverter);
}
@Test
public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() {
// @formatter:off
@ -291,6 +373,69 @@ public class WebClientReactivePasswordTokenResponseClientTests {
.isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
}
@Test
public void setParametersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.setParametersConverter(null))
.withMessage("parametersConverter cannot be null");
}
@Test
public void addParametersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.addParametersConverter(null))
.withMessage("parametersConverter cannot be null");
}
@Test
public void convertWhenParametersConverterAddedThenCalled() throws Exception {
OAuth2PasswordGrantRequest request = new OAuth2PasswordGrantRequest(this.clientRegistrationBuilder.build(),
this.username, this.password);
Converter<OAuth2PasswordGrantRequest, MultiValueMap<String, String>> addedParametersConverter = mock(
Converter.class);
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add("custom-parameter-name", "custom-parameter-value");
given(addedParametersConverter.convert(request)).willReturn(parameters);
this.tokenResponseClient.addParametersConverter(addedParametersConverter);
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
+ "}";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.tokenResponseClient.getTokenResponse(request).block();
verify(addedParametersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
assertThat(actualRequest.getBody().readUtf8()).contains("grant_type=password",
"custom-parameter-name=custom-parameter-value");
}
@Test
public void convertWhenParametersConverterSetThenCalled() throws Exception {
OAuth2PasswordGrantRequest request = new OAuth2PasswordGrantRequest(this.clientRegistrationBuilder.build(),
this.username, this.password);
Converter<OAuth2PasswordGrantRequest, MultiValueMap<String, String>> parametersConverter = mock(
Converter.class);
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add("custom-parameter-name", "custom-parameter-value");
given(parametersConverter.convert(request)).willReturn(parameters);
this.tokenResponseClient.setParametersConverter(parametersConverter);
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
+ "}";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.tokenResponseClient.getTokenResponse(request).block();
verify(parametersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
assertThat(actualRequest.getBody().readUtf8()).contains("custom-parameter-name=custom-parameter-value");
}
// gh-10260
@Test
public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse() {

View File

@ -16,9 +16,14 @@
package org.springframework.security.oauth2.client.endpoint;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.Collections;
import java.util.function.Function;
import javax.crypto.spec.SecretKeySpec;
import com.nimbusds.jose.jwk.JWK;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
@ -42,6 +47,10 @@ import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
import org.springframework.security.oauth2.jose.TestJwks;
import org.springframework.security.oauth2.jose.TestKeys;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.reactive.function.BodyExtractor;
import static org.assertj.core.api.Assertions.assertThat;
@ -149,6 +158,79 @@ public class WebClientReactiveRefreshTokenTokenResponseClientTests {
assertThat(formParameters).contains("client_secret=client-secret");
}
@Test
public void getTokenResponseWhenAuthenticationClientSecretJwtThenFormParametersAreSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
// @formatter:off
ClientRegistration clientRegistration = this.clientRegistrationBuilder
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT)
.clientSecret(TestKeys.DEFAULT_ENCODED_SECRET_KEY)
.build();
// @formatter:on
// Configure Jwt client authentication converter
SecretKeySpec secretKey = new SecretKeySpec(
clientRegistration.getClientSecret().getBytes(StandardCharsets.UTF_8), "HmacSHA256");
JWK jwk = TestJwks.jwk(secretKey).build();
Function<ClientRegistration, JWK> jwkResolver = (registration) -> jwk;
configureJwtClientAuthenticationConverter(jwkResolver);
OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration,
this.accessToken, this.refreshToken);
this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest).block();
RecordedRequest actualRequest = this.server.takeRequest();
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull();
assertThat(actualRequest.getBody().readUtf8()).contains("grant_type=refresh_token",
"client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer",
"client_assertion=");
}
@Test
public void getTokenResponseWhenAuthenticationPrivateKeyJwtThenFormParametersAreSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
// @formatter:off
ClientRegistration clientRegistration = this.clientRegistrationBuilder
.clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT)
.build();
// @formatter:on
// Configure Jwt client authentication converter
JWK jwk = TestJwks.DEFAULT_RSA_JWK;
Function<ClientRegistration, JWK> jwkResolver = (registration) -> jwk;
configureJwtClientAuthenticationConverter(jwkResolver);
OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration,
this.accessToken, this.refreshToken);
this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest).block();
RecordedRequest actualRequest = this.server.takeRequest();
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull();
assertThat(actualRequest.getBody().readUtf8()).contains("grant_type=refresh_token",
"client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer",
"client_assertion=");
}
private void configureJwtClientAuthenticationConverter(Function<ClientRegistration, JWK> jwkResolver) {
NimbusJwtClientAuthenticationParametersConverter<OAuth2RefreshTokenGrantRequest> jwtClientAuthenticationConverter = new NimbusJwtClientAuthenticationParametersConverter<>(
jwkResolver);
this.tokenResponseClient.addParametersConverter(jwtClientAuthenticationConverter);
}
@Test
public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() {
// @formatter:off
@ -294,6 +376,69 @@ public class WebClientReactiveRefreshTokenTokenResponseClientTests {
.isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
}
@Test
public void setParametersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.setParametersConverter(null))
.withMessage("parametersConverter cannot be null");
}
@Test
public void addParametersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.addParametersConverter(null))
.withMessage("parametersConverter cannot be null");
}
@Test
public void convertWhenParametersConverterAddedThenCalled() throws Exception {
OAuth2RefreshTokenGrantRequest request = new OAuth2RefreshTokenGrantRequest(
this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken);
Converter<OAuth2RefreshTokenGrantRequest, MultiValueMap<String, String>> addedParametersConverter = mock(
Converter.class);
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add("custom-parameter-name", "custom-parameter-value");
given(addedParametersConverter.convert(request)).willReturn(parameters);
this.tokenResponseClient.addParametersConverter(addedParametersConverter);
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
+ "}";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.tokenResponseClient.getTokenResponse(request).block();
verify(addedParametersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
assertThat(actualRequest.getBody().readUtf8()).contains("grant_type=refresh_token",
"custom-parameter-name=custom-parameter-value");
}
@Test
public void convertWhenParametersConverterSetThenCalled() throws Exception {
OAuth2RefreshTokenGrantRequest request = new OAuth2RefreshTokenGrantRequest(
this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken);
Converter<OAuth2RefreshTokenGrantRequest, MultiValueMap<String, String>> parametersConverter = mock(
Converter.class);
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add("custom-parameter-name", "custom-parameter-value");
given(parametersConverter.convert(request)).willReturn(parameters);
this.tokenResponseClient.setParametersConverter(parametersConverter);
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
+ "}";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.tokenResponseClient.getTokenResponse(request).block();
verify(parametersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
assertThat(actualRequest.getBody().readUtf8()).contains("custom-parameter-name=custom-parameter-value");
}
// gh-10260
@Test
public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse() {