Allow access token request parameters to override defaults

Closes gh-11298
This commit is contained in:
Steve Riesenberg 2024-04-22 17:09:05 -05:00
parent 8c2485cb47
commit f5991ae176
49 changed files with 1281 additions and 1816 deletions

View File

@ -25,7 +25,6 @@ import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
@ -75,7 +74,7 @@ public abstract class AbstractRestClientOAuth2AccessTokenResponseClient<T extend
private Converter<T, HttpHeaders> headersConverter = new DefaultOAuth2TokenRequestHeadersConverter<>();
private Converter<T, MultiValueMap<String, String>> parametersConverter = this::createParameters;
private Converter<T, MultiValueMap<String, String>> parametersConverter = new DefaultOAuth2TokenRequestParametersConverter<>();
AbstractRestClientOAuth2AccessTokenResponseClient() {
}
@ -124,6 +123,11 @@ public abstract class AbstractRestClientOAuth2AccessTokenResponseClient<T extend
}
private RequestHeadersSpec<?> populateRequest(T grantRequest) {
MultiValueMap<String, String> parameters = this.parametersConverter.convert(grantRequest);
if (parameters == null) {
parameters = new LinkedMultiValueMap<>();
}
return this.restClient.post()
.uri(grantRequest.getClientRegistration().getProviderDetails().getTokenUri())
.headers((headers) -> {
@ -132,28 +136,7 @@ public abstract class AbstractRestClientOAuth2AccessTokenResponseClient<T extend
headers.addAll(headersToAdd);
}
})
.body(this.parametersConverter.convert(grantRequest));
}
/**
* Returns a {@link MultiValueMap} of the parameters used in the OAuth 2.0 Access
* Token Request body.
* @param grantRequest the authorization grant request
* @return a {@link MultiValueMap} of the parameters used in the OAuth 2.0 Access
* Token Request body
*/
MultiValueMap<String, String> createParameters(T grantRequest) {
ClientRegistration clientRegistration = grantRequest.getClientRegistration();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.set(OAuth2ParameterNames.GRANT_TYPE, grantRequest.getGrantType().getValue());
if (!ClientAuthenticationMethod.CLIENT_SECRET_BASIC
.equals(clientRegistration.getClientAuthenticationMethod())) {
parameters.set(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId());
}
if (ClientAuthenticationMethod.CLIENT_SECRET_POST.equals(clientRegistration.getClientAuthenticationMethod())) {
parameters.set(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret());
}
return parameters;
.body(parameters);
}
/**
@ -216,7 +199,21 @@ public abstract class AbstractRestClientOAuth2AccessTokenResponseClient<T extend
*/
public final void setParametersConverter(Converter<T, MultiValueMap<String, String>> parametersConverter) {
Assert.notNull(parametersConverter, "parametersConverter cannot be null");
this.parametersConverter = parametersConverter;
if (parametersConverter instanceof DefaultOAuth2TokenRequestParametersConverter) {
this.parametersConverter = parametersConverter;
}
else {
Converter<T, MultiValueMap<String, String>> defaultParametersConverter = new DefaultOAuth2TokenRequestParametersConverter<>();
this.parametersConverter = (authorizationGrantRequest) -> {
MultiValueMap<String, String> parameters = defaultParametersConverter
.convert(authorizationGrantRequest);
MultiValueMap<String, String> parametersToSet = parametersConverter.convert(authorizationGrantRequest);
if (parametersToSet != null) {
parameters.putAll(parametersToSet);
}
return parameters;
};
}
this.requestEntityConverter = this::populateRequest;
}

View File

@ -16,9 +16,6 @@
package org.springframework.security.oauth2.client.endpoint;
import java.util.Collections;
import java.util.Set;
import reactor.core.publisher.Mono;
import org.springframework.core.convert.converter.Converter;
@ -27,16 +24,12 @@ import org.springframework.http.ReactiveHttpInputMessage;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import org.springframework.web.reactive.function.BodyExtractor;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.reactive.function.client.WebClient.RequestHeadersSpec;
@ -54,6 +47,7 @@ import org.springframework.web.reactive.function.client.WebClient.RequestHeaders
*
* @param <T> type of grant request
* @author Phil Clay
* @author Steve Riesenberg
* @since 5.3
* @see <a href="https://tools.ietf.org/html/rfc6749#section-3.2">RFC-6749 Token
* Endpoint</a>
@ -72,7 +66,7 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
private Converter<T, HttpHeaders> headersConverter = new DefaultOAuth2TokenRequestHeadersConverter<>();
private Converter<T, MultiValueMap<String, String>> parametersConverter = this::populateTokenRequestParameters;
private Converter<T, MultiValueMap<String, String>> parametersConverter = new DefaultOAuth2TokenRequestParametersConverter<>();
private BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> bodyExtractor = OAuth2BodyExtractors
.oauth2AccessTokenResponse();
@ -86,18 +80,11 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
// @formatter:off
return Mono.defer(() -> this.requestEntityConverter.convert(grantRequest)
.exchange()
.flatMap((response) -> readTokenResponse(grantRequest, response))
.flatMap((response) -> response.body(this.bodyExtractor))
);
// @formatter:on
}
/**
* Returns the {@link ClientRegistration} for the given {@code grantRequest}.
* @param grantRequest the grant request
* @return the {@link ClientRegistration} for the given {@code grantRequest}.
*/
abstract ClientRegistration clientRegistration(T grantRequest);
private RequestHeadersSpec<?> validatingPopulateRequest(T grantRequest) {
validateClientAuthenticationMethod(grantRequest);
return populateRequest(grantRequest);
@ -117,128 +104,20 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
}
private RequestHeadersSpec<?> populateRequest(T grantRequest) {
MultiValueMap<String, String> parameters = this.parametersConverter.convert(grantRequest);
if (parameters == null) {
parameters = new LinkedMultiValueMap<>();
}
return this.webClient.post()
.uri(clientRegistration(grantRequest).getProviderDetails().getTokenUri())
.uri(grantRequest.getClientRegistration().getProviderDetails().getTokenUri())
.headers((headers) -> {
HttpHeaders headersToAdd = getHeadersConverter().convert(grantRequest);
HttpHeaders headersToAdd = this.headersConverter.convert(grantRequest);
if (headersToAdd != null) {
headers.addAll(headersToAdd);
}
})
.body(createTokenRequestBody(grantRequest));
}
/**
* Populates default parameters for the token request.
* @param grantRequest the grant request
* @return the parameters populated for the token request.
*/
private MultiValueMap<String, String> populateTokenRequestParameters(T grantRequest) {
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add(OAuth2ParameterNames.GRANT_TYPE, grantRequest.getGrantType().getValue());
return parameters;
}
/**
* Combine the results of {@code parametersConverter} and
* {@link #populateTokenRequestBody}.
*
* <p>
* This method pre-populates the body with some standard properties, and then
* delegates to
* {@link #populateTokenRequestBody(AbstractOAuth2AuthorizationGrantRequest, BodyInserters.FormInserter)}
* for subclasses to further populate the body before returning.
* </p>
* @param grantRequest the grant request
* @return the body for the token request.
*/
private BodyInserters.FormInserter<String> createTokenRequestBody(T grantRequest) {
MultiValueMap<String, String> parameters = getParametersConverter().convert(grantRequest);
return populateTokenRequestBody(grantRequest, BodyInserters.fromFormData(parameters));
}
/**
* Populates the body of the token request.
*
* <p>
* By default, populates properties that are common to all grant types. Subclasses can
* extend this method to populate grant type specific properties.
* </p>
* @param grantRequest the grant request
* @param body the body to populate
* @return the populated body
*/
BodyInserters.FormInserter<String> populateTokenRequestBody(T grantRequest,
BodyInserters.FormInserter<String> body) {
ClientRegistration clientRegistration = clientRegistration(grantRequest);
if (!ClientAuthenticationMethod.CLIENT_SECRET_BASIC
.equals(clientRegistration.getClientAuthenticationMethod())) {
body.with(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId());
}
if (ClientAuthenticationMethod.CLIENT_SECRET_POST.equals(clientRegistration.getClientAuthenticationMethod())) {
body.with(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret());
}
Set<String> scopes = scopes(grantRequest);
if (!CollectionUtils.isEmpty(scopes)) {
body.with(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(scopes, " "));
}
return body;
}
/**
* Returns the scopes to include as a property in the token request.
* @param grantRequest the grant request
* @return the scopes to include as a property in the token request.
*/
abstract Set<String> scopes(T grantRequest);
/**
* Returns the scopes to include in the response if the authorization server returned
* no scopes in the response.
*
* <p>
* As per <a href="https://tools.ietf.org/html/rfc6749#section-5.1">RFC-6749 Section
* 5.1 Successful Access Token Response</a>, if AccessTokenResponse.scope is empty,
* then default to the scope originally requested by the client in the Token Request.
* </p>
* @param grantRequest the grant request
* @return the scopes to include in the response if the authorization server returned
* no scopes.
*/
Set<String> defaultScopes(T grantRequest) {
return Collections.emptySet();
}
/**
* Reads the token response from the response body.
* @param grantRequest the request for which the response was received.
* @param response the client response from which to read
* @return the token response from the response body.
*/
private Mono<OAuth2AccessTokenResponse> readTokenResponse(T grantRequest, ClientResponse response) {
return response.body(this.bodyExtractor)
.map((tokenResponse) -> populateTokenResponse(grantRequest, tokenResponse));
}
/**
* Populates the given {@link OAuth2AccessTokenResponse} with additional details from
* the grant request.
* @param grantRequest the request for which the response was received.
* @param tokenResponse the original token response
* @return a token response optionally populated with additional details from the
* request.
*/
OAuth2AccessTokenResponse populateTokenResponse(T grantRequest, OAuth2AccessTokenResponse tokenResponse) {
if (CollectionUtils.isEmpty(tokenResponse.getAccessToken().getScopes())) {
Set<String> defaultScopes = defaultScopes(grantRequest);
// @formatter:off
tokenResponse = OAuth2AccessTokenResponse
.withResponse(tokenResponse)
.scopes(defaultScopes)
.build();
// @formatter:on
}
return tokenResponse;
.body(BodyInserters.fromFormData(parameters));
}
/**
@ -247,22 +126,11 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
* @param webClient the {@link WebClient} used when requesting the Access Token
* Response
*/
public void setWebClient(WebClient webClient) {
public final void setWebClient(WebClient webClient) {
Assert.notNull(webClient, "webClient cannot be null");
this.webClient = webClient;
}
/**
* Returns the {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link HttpHeaders}
* used in the OAuth 2.0 Access Token Request headers.
* @return the {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} to {@link HttpHeaders}
*/
final Converter<T, HttpHeaders> getHeadersConverter() {
return this.headersConverter;
}
/**
* Sets the {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link HttpHeaders}
@ -305,17 +173,6 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
this.requestEntityConverter = this::populateRequest;
}
/**
* Returns the {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link MultiValueMap}
* used in the OAuth 2.0 Access Token Request body.
* @return the {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} to {@link MultiValueMap}
*/
final Converter<T, MultiValueMap<String, String>> getParametersConverter() {
return this.parametersConverter;
}
/**
* Sets the {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link MultiValueMap}
@ -326,7 +183,21 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
*/
public final void setParametersConverter(Converter<T, MultiValueMap<String, String>> parametersConverter) {
Assert.notNull(parametersConverter, "parametersConverter cannot be null");
this.parametersConverter = parametersConverter;
if (parametersConverter instanceof DefaultOAuth2TokenRequestParametersConverter) {
this.parametersConverter = parametersConverter;
}
else {
Converter<T, MultiValueMap<String, String>> defaultParametersConverter = new DefaultOAuth2TokenRequestParametersConverter<>();
this.parametersConverter = (authorizationGrantRequest) -> {
MultiValueMap<String, String> parameters = defaultParametersConverter
.convert(authorizationGrantRequest);
MultiValueMap<String, String> parametersToSet = parametersConverter.convert(authorizationGrantRequest);
if (parametersToSet != null) {
parameters.putAll(parametersToSet);
}
return parameters;
};
}
this.requestEntityConverter = this::populateRequest;
}

View File

@ -0,0 +1,126 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.endpoint;
import java.util.function.Consumer;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
/**
* Default {@link Converter} used to convert an
* {@link AbstractOAuth2AuthorizationGrantRequest} to the default {@link MultiValueMap
* parameters} of an OAuth 2.0 Access Token Request.
* <p>
* This implementation provides grant-type specific parameters for the following grant
* types:
*
* <ul>
* <li>{@code authorization_code}</li>
* <li>{@code refresh_token}</li>
* <li>{@code client_credentials}</li>
* <li>{@code password}</li>
* <li>{@code urn:ietf:params:oauth:grant-type:jwt-bearer}</li>
* <li>{@code urn:ietf:params:oauth:grant-type:token-exchange}</li>
* </ul>
*
* In addition, the following default parameters are provided:
*
* <ul>
* <li>{@code grant_type} - always provided</li>
* <li>{@code client_id} - provided unless the {@code clientAuthenticationMethod} is
* {@code client_secret_basic}</li>
* <li>{@code client_secret} - provided when the {@code clientAuthenticationMethod} is
* {@code client_secret_post}</li>
* </ul>
*
* @param <T> type of grant request
* @author Steve Riesenberg
* @since 6.4
* @see AbstractWebClientReactiveOAuth2AccessTokenResponseClient
* @see AbstractRestClientOAuth2AccessTokenResponseClient
*/
public final class DefaultOAuth2TokenRequestParametersConverter<T extends AbstractOAuth2AuthorizationGrantRequest>
implements Converter<T, MultiValueMap<String, String>> {
private final Converter<T, MultiValueMap<String, String>> defaultParametersConverter = createDefaultParametersConverter();
private Consumer<MultiValueMap<String, String>> parametersCustomizer = (parameters) -> {
};
/**
* Sets the {@link Consumer} used for customizing the OAuth 2.0 Access Token
* parameters, which allows for parameters to be added, overwritten or removed.
* @param parametersCustomizer the {@link Consumer} to customize the parameters
*/
public void setParametersCustomizer(Consumer<MultiValueMap<String, String>> parametersCustomizer) {
Assert.notNull(parametersCustomizer, "parametersCustomizer cannot be null");
this.parametersCustomizer = parametersCustomizer;
}
@Override
public MultiValueMap<String, String> convert(T grantRequest) {
ClientRegistration clientRegistration = grantRequest.getClientRegistration();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.set(OAuth2ParameterNames.GRANT_TYPE, grantRequest.getGrantType().getValue());
if (!ClientAuthenticationMethod.CLIENT_SECRET_BASIC
.equals(clientRegistration.getClientAuthenticationMethod())) {
parameters.set(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId());
}
if (ClientAuthenticationMethod.CLIENT_SECRET_POST.equals(clientRegistration.getClientAuthenticationMethod())) {
parameters.set(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret());
}
MultiValueMap<String, String> defaultParameters = this.defaultParametersConverter.convert(grantRequest);
if (defaultParameters != null) {
parameters.addAll(defaultParameters);
}
this.parametersCustomizer.accept(parameters);
return parameters;
}
private static <T extends AbstractOAuth2AuthorizationGrantRequest> Converter<T, MultiValueMap<String, String>> createDefaultParametersConverter() {
return (grantRequest) -> {
if (grantRequest instanceof OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest) {
return OAuth2AuthorizationCodeGrantRequest.defaultParameters(authorizationCodeGrantRequest);
}
else if (grantRequest instanceof OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest) {
return OAuth2ClientCredentialsGrantRequest.defaultParameters(clientCredentialsGrantRequest);
}
else if (grantRequest instanceof OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest) {
return OAuth2RefreshTokenGrantRequest.defaultParameters(refreshTokenGrantRequest);
}
else if (grantRequest instanceof OAuth2PasswordGrantRequest passwordGrantRequest) {
return OAuth2PasswordGrantRequest.defaultParameters(passwordGrantRequest);
}
else if (grantRequest instanceof JwtBearerGrantRequest jwtBearerGrantRequest) {
return JwtBearerGrantRequest.defaultParameters(jwtBearerGrantRequest);
}
else if (grantRequest instanceof TokenExchangeGrantRequest tokenExchangeGrantRequest) {
return TokenExchangeGrantRequest.defaultParameters(tokenExchangeGrantRequest);
}
return null;
};
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -18,8 +18,13 @@ package org.springframework.security.oauth2.client.endpoint;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
/**
* A JWT Bearer Grant request that holds a {@link Jwt} assertion.
@ -57,4 +62,21 @@ public class JwtBearerGrantRequest extends AbstractOAuth2AuthorizationGrantReque
return this.jwt;
}
/**
* Populate default parameters for the JWT Bearer Grant.
* @param grantRequest the authorization grant request
* @return a {@link MultiValueMap} of the parameters used in the OAuth 2.0 Access
* Token Request body
*/
static MultiValueMap<String, String> defaultParameters(JwtBearerGrantRequest grantRequest) {
ClientRegistration clientRegistration = grantRequest.getClientRegistration();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) {
parameters.set(OAuth2ParameterNames.SCOPE,
StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " "));
}
parameters.set(OAuth2ParameterNames.ASSERTION, grantRequest.getJwt().getTokenValue());
return parameters;
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -37,7 +37,9 @@ import org.springframework.util.StringUtils;
* @see RequestEntity
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc7523#section-2.1">Section
* 2.1 Using JWTs as Authorization Grants</a>
* @deprecated Use {@link DefaultOAuth2TokenRequestParametersConverter} instead
*/
@Deprecated(since = "6.4")
public class JwtBearerGrantRequestEntityConverter
extends AbstractOAuth2AuthorizationGrantRequestEntityConverter<JwtBearerGrantRequest> {

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,7 +19,11 @@ package org.springframework.security.oauth2.client.endpoint;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
/**
* An OAuth 2.0 Authorization Code Grant request that holds an Authorization Code
@ -60,4 +64,26 @@ public class OAuth2AuthorizationCodeGrantRequest extends AbstractOAuth2Authoriza
return this.authorizationExchange;
}
/**
* Populate default parameters for the Authorization Code Grant.
* @param grantRequest the authorization grant request
* @return a {@link MultiValueMap} of the parameters used in the OAuth 2.0 Access
* Token Request body
*/
static MultiValueMap<String, String> defaultParameters(OAuth2AuthorizationCodeGrantRequest grantRequest) {
OAuth2AuthorizationExchange authorizationExchange = grantRequest.getAuthorizationExchange();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.set(OAuth2ParameterNames.CODE, authorizationExchange.getAuthorizationResponse().getCode());
String redirectUri = authorizationExchange.getAuthorizationRequest().getRedirectUri();
if (redirectUri != null) {
parameters.set(OAuth2ParameterNames.REDIRECT_URI, redirectUri);
}
String codeVerifier = authorizationExchange.getAuthorizationRequest()
.getAttribute(PkceParameterNames.CODE_VERIFIER);
if (codeVerifier != null) {
parameters.set(PkceParameterNames.CODE_VERIFIER, codeVerifier);
}
return parameters;
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -36,7 +36,9 @@ import org.springframework.util.MultiValueMap;
* @see AbstractOAuth2AuthorizationGrantRequestEntityConverter
* @see OAuth2AuthorizationCodeGrantRequest
* @see RequestEntity
* @deprecated Use {@link DefaultOAuth2TokenRequestParametersConverter} instead
*/
@Deprecated(since = "6.4")
public class OAuth2AuthorizationCodeGrantRequestEntityConverter
extends AbstractOAuth2AuthorizationGrantRequestEntityConverter<OAuth2AuthorizationCodeGrantRequest> {

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -18,7 +18,12 @@ package org.springframework.security.oauth2.client.endpoint;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
/**
* An OAuth 2.0 Client Credentials Grant request that holds the client's credentials in
@ -45,4 +50,20 @@ public class OAuth2ClientCredentialsGrantRequest extends AbstractOAuth2Authoriza
"clientRegistration.authorizationGrantType must be AuthorizationGrantType.CLIENT_CREDENTIALS");
}
/**
* Populate default parameters for the Client Credentials Grant.
* @param grantRequest the authorization grant request
* @return a {@link MultiValueMap} of the parameters used in the OAuth 2.0 Access
* Token Request body
*/
static MultiValueMap<String, String> defaultParameters(OAuth2ClientCredentialsGrantRequest grantRequest) {
ClientRegistration clientRegistration = grantRequest.getClientRegistration();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) {
parameters.set(OAuth2ParameterNames.SCOPE,
StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " "));
}
return parameters;
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -36,7 +36,9 @@ import org.springframework.util.StringUtils;
* @see AbstractOAuth2AuthorizationGrantRequestEntityConverter
* @see OAuth2ClientCredentialsGrantRequest
* @see RequestEntity
* @deprecated Use {@link DefaultOAuth2TokenRequestParametersConverter} instead
*/
@Deprecated(since = "6.4")
public class OAuth2ClientCredentialsGrantRequestEntityConverter
extends AbstractOAuth2AuthorizationGrantRequestEntityConverter<OAuth2ClientCredentialsGrantRequest> {

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -18,7 +18,12 @@ package org.springframework.security.oauth2.client.endpoint;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
/**
* An OAuth 2.0 Resource Owner Password Credentials Grant request that holds the resource
@ -74,4 +79,22 @@ public class OAuth2PasswordGrantRequest extends AbstractOAuth2AuthorizationGrant
return this.password;
}
/**
* Populate default parameters for the Password Grant.
* @param grantRequest the authorization grant request
* @return a {@link MultiValueMap} of the parameters used in the OAuth 2.0 Access
* Token Request body
*/
static MultiValueMap<String, String> defaultParameters(OAuth2PasswordGrantRequest grantRequest) {
ClientRegistration clientRegistration = grantRequest.getClientRegistration();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) {
parameters.set(OAuth2ParameterNames.SCOPE,
StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " "));
}
parameters.set(OAuth2ParameterNames.USERNAME, grantRequest.getUsername());
parameters.set(OAuth2ParameterNames.PASSWORD, grantRequest.getPassword());
return parameters;
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -36,7 +36,9 @@ import org.springframework.util.StringUtils;
* @see AbstractOAuth2AuthorizationGrantRequestEntityConverter
* @see OAuth2PasswordGrantRequest
* @see RequestEntity
* @deprecated Use {@link DefaultOAuth2TokenRequestParametersConverter} instead
*/
@Deprecated(since = "6.4")
public class OAuth2PasswordGrantRequestEntityConverter
extends AbstractOAuth2AuthorizationGrantRequestEntityConverter<OAuth2PasswordGrantRequest> {

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -24,7 +24,12 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
/**
* An OAuth 2.0 Refresh Token Grant request that holds the {@link OAuth2RefreshToken
@ -98,4 +103,20 @@ public class OAuth2RefreshTokenGrantRequest extends AbstractOAuth2AuthorizationG
return this.scopes;
}
/**
* Populate default parameters for the Refresh Token Grant.
* @param grantRequest the authorization grant request
* @return a {@link MultiValueMap} of the parameters used in the OAuth 2.0 Access
* Token Request body
*/
static MultiValueMap<String, String> defaultParameters(OAuth2RefreshTokenGrantRequest grantRequest) {
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
if (!CollectionUtils.isEmpty(grantRequest.getScopes())) {
parameters.set(OAuth2ParameterNames.SCOPE,
StringUtils.collectionToDelimitedString(grantRequest.getScopes(), " "));
}
parameters.set(OAuth2ParameterNames.REFRESH_TOKEN, grantRequest.getRefreshToken().getTokenValue());
return parameters;
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -36,7 +36,9 @@ import org.springframework.util.StringUtils;
* @see AbstractOAuth2AuthorizationGrantRequestEntityConverter
* @see OAuth2RefreshTokenGrantRequest
* @see RequestEntity
* @deprecated Use {@link DefaultOAuth2TokenRequestParametersConverter} instead
*/
@Deprecated(since = "6.4")
public class OAuth2RefreshTokenGrantRequestEntityConverter
extends AbstractOAuth2AuthorizationGrantRequestEntityConverter<OAuth2RefreshTokenGrantRequest> {

View File

@ -17,10 +17,6 @@
package org.springframework.security.oauth2.client.endpoint;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.util.MultiValueMap;
/**
* An implementation of {@link OAuth2AccessTokenResponseClient} that &quot;exchanges&quot;
@ -43,21 +39,4 @@ import org.springframework.util.MultiValueMap;
public final class RestClientAuthorizationCodeTokenResponseClient
extends AbstractRestClientOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> {
@Override
MultiValueMap<String, String> createParameters(OAuth2AuthorizationCodeGrantRequest grantRequest) {
OAuth2AuthorizationExchange authorizationExchange = grantRequest.getAuthorizationExchange();
MultiValueMap<String, String> parameters = super.createParameters(grantRequest);
parameters.set(OAuth2ParameterNames.CODE, authorizationExchange.getAuthorizationResponse().getCode());
String redirectUri = authorizationExchange.getAuthorizationRequest().getRedirectUri();
if (redirectUri != null) {
parameters.set(OAuth2ParameterNames.REDIRECT_URI, redirectUri);
}
String codeVerifier = authorizationExchange.getAuthorizationRequest()
.getAttribute(PkceParameterNames.CODE_VERIFIER);
if (codeVerifier != null) {
parameters.set(PkceParameterNames.CODE_VERIFIER, codeVerifier);
}
return parameters;
}
}

View File

@ -16,12 +16,7 @@
package org.springframework.security.oauth2.client.endpoint;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.CollectionUtils;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
/**
* An implementation of {@link OAuth2AccessTokenResponseClient} that &quot;exchanges&quot;
@ -42,15 +37,4 @@ import org.springframework.util.StringUtils;
public final class RestClientClientCredentialsTokenResponseClient
extends AbstractRestClientOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> {
@Override
MultiValueMap<String, String> createParameters(OAuth2ClientCredentialsGrantRequest grantRequest) {
ClientRegistration clientRegistration = grantRequest.getClientRegistration();
MultiValueMap<String, String> parameters = super.createParameters(grantRequest);
if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) {
parameters.set(OAuth2ParameterNames.SCOPE,
StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " "));
}
return parameters;
}
}

View File

@ -16,12 +16,7 @@
package org.springframework.security.oauth2.client.endpoint;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.CollectionUtils;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
/**
* An implementation of {@link OAuth2AccessTokenResponseClient} that &quot;exchanges&quot;
@ -40,16 +35,4 @@ import org.springframework.util.StringUtils;
public final class RestClientJwtBearerTokenResponseClient
extends AbstractRestClientOAuth2AccessTokenResponseClient<JwtBearerGrantRequest> {
@Override
MultiValueMap<String, String> createParameters(JwtBearerGrantRequest grantRequest) {
ClientRegistration clientRegistration = grantRequest.getClientRegistration();
MultiValueMap<String, String> parameters = super.createParameters(grantRequest);
if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) {
parameters.set(OAuth2ParameterNames.SCOPE,
StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " "));
}
parameters.set(OAuth2ParameterNames.ASSERTION, grantRequest.getJwt().getTokenValue());
return parameters;
}
}

View File

@ -17,10 +17,7 @@
package org.springframework.security.oauth2.client.endpoint;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.CollectionUtils;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
/**
* An implementation of {@link OAuth2AccessTokenResponseClient} that &quot;exchanges&quot;
@ -43,17 +40,6 @@ public final class RestClientRefreshTokenTokenResponseClient
return populateTokenResponse(grantRequest, accessTokenResponse);
}
@Override
MultiValueMap<String, String> createParameters(OAuth2RefreshTokenGrantRequest grantRequest) {
MultiValueMap<String, String> parameters = super.createParameters(grantRequest);
if (!CollectionUtils.isEmpty(grantRequest.getScopes())) {
parameters.set(OAuth2ParameterNames.SCOPE,
StringUtils.collectionToDelimitedString(grantRequest.getScopes(), " "));
}
parameters.set(OAuth2ParameterNames.REFRESH_TOKEN, grantRequest.getRefreshToken().getTokenValue());
return parameters;
}
private OAuth2AccessTokenResponse populateTokenResponse(OAuth2RefreshTokenGrantRequest grantRequest,
OAuth2AccessTokenResponse accessTokenResponse) {
if (!CollectionUtils.isEmpty(accessTokenResponse.getAccessToken().getScopes())

View File

@ -16,14 +16,7 @@
package org.springframework.security.oauth2.client.endpoint;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.OAuth2Token;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.util.CollectionUtils;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
/**
* An implementation of {@link OAuth2AccessTokenResponseClient} that &quot;exchanges&quot;
@ -43,32 +36,4 @@ import org.springframework.util.StringUtils;
public final class RestClientTokenExchangeTokenResponseClient
extends AbstractRestClientOAuth2AccessTokenResponseClient<TokenExchangeGrantRequest> {
private static final String ACCESS_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:access_token";
private static final String JWT_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:jwt";
@Override
MultiValueMap<String, String> createParameters(TokenExchangeGrantRequest grantRequest) {
ClientRegistration clientRegistration = grantRequest.getClientRegistration();
MultiValueMap<String, String> parameters = super.createParameters(grantRequest);
if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) {
parameters.set(OAuth2ParameterNames.SCOPE,
StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " "));
}
parameters.set(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE);
OAuth2Token subjectToken = grantRequest.getSubjectToken();
parameters.set(OAuth2ParameterNames.SUBJECT_TOKEN, subjectToken.getTokenValue());
parameters.set(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE, tokenType(subjectToken));
OAuth2Token actorToken = grantRequest.getActorToken();
if (actorToken != null) {
parameters.set(OAuth2ParameterNames.ACTOR_TOKEN, actorToken.getTokenValue());
parameters.set(OAuth2ParameterNames.ACTOR_TOKEN_TYPE, tokenType(actorToken));
}
return parameters;
}
private static String tokenType(OAuth2Token token) {
return (token instanceof Jwt) ? JWT_TOKEN_TYPE_VALUE : ACCESS_TOKEN_TYPE_VALUE;
}
}

View File

@ -19,7 +19,13 @@ package org.springframework.security.oauth2.client.endpoint;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2Token;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
/**
* A Token Exchange Grant request that holds the {@link OAuth2Token subject token} and
@ -39,6 +45,10 @@ import org.springframework.util.Assert;
*/
public class TokenExchangeGrantRequest extends AbstractOAuth2AuthorizationGrantRequest {
private static final String ACCESS_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:access_token";
private static final String JWT_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:jwt";
private final OAuth2Token subjectToken;
private final OAuth2Token actorToken;
@ -75,4 +85,33 @@ public class TokenExchangeGrantRequest extends AbstractOAuth2AuthorizationGrantR
return this.actorToken;
}
/**
* Populate default parameters for the Token Exchange Grant.
* @param grantRequest the authorization grant request
* @return a {@link MultiValueMap} of the parameters used in the OAuth 2.0 Access
* Token Request body
*/
static MultiValueMap<String, String> defaultParameters(TokenExchangeGrantRequest grantRequest) {
ClientRegistration clientRegistration = grantRequest.getClientRegistration();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) {
parameters.set(OAuth2ParameterNames.SCOPE,
StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " "));
}
parameters.set(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE);
OAuth2Token subjectToken = grantRequest.getSubjectToken();
parameters.set(OAuth2ParameterNames.SUBJECT_TOKEN, subjectToken.getTokenValue());
parameters.set(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE, tokenType(subjectToken));
OAuth2Token actorToken = grantRequest.getActorToken();
if (actorToken != null) {
parameters.set(OAuth2ParameterNames.ACTOR_TOKEN, actorToken.getTokenValue());
parameters.set(OAuth2ParameterNames.ACTOR_TOKEN_TYPE, tokenType(actorToken));
}
return parameters;
}
private static String tokenType(OAuth2Token token) {
return (token instanceof Jwt) ? JWT_TOKEN_TYPE_VALUE : ACCESS_TOKEN_TYPE_VALUE;
}
}

View File

@ -39,7 +39,9 @@ import org.springframework.util.StringUtils;
* @see RequestEntity
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc8693#section-1.1">Section
* 1.1 Delegation vs. Impersonation Semantics</a>
* @deprecated Use {@link DefaultOAuth2TokenRequestParametersConverter} instead
*/
@Deprecated(since = "6.4")
public class TokenExchangeGrantRequestEntityConverter
extends AbstractOAuth2AuthorizationGrantRequestEntityConverter<TokenExchangeGrantRequest> {

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,16 +16,7 @@
package org.springframework.security.oauth2.client.endpoint;
import java.util.Collections;
import java.util.Set;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.web.reactive.function.BodyInserters;
/**
* An implementation of a {@link ReactiveOAuth2AccessTokenResponseClient} that
@ -55,33 +46,4 @@ import org.springframework.web.reactive.function.BodyInserters;
public class WebClientReactiveAuthorizationCodeTokenResponseClient
extends AbstractWebClientReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> {
@Override
ClientRegistration clientRegistration(OAuth2AuthorizationCodeGrantRequest grantRequest) {
return grantRequest.getClientRegistration();
}
@Override
Set<String> scopes(OAuth2AuthorizationCodeGrantRequest grantRequest) {
return Collections.emptySet();
}
@Override
BodyInserters.FormInserter<String> populateTokenRequestBody(OAuth2AuthorizationCodeGrantRequest grantRequest,
BodyInserters.FormInserter<String> body) {
super.populateTokenRequestBody(grantRequest, body);
OAuth2AuthorizationExchange authorizationExchange = grantRequest.getAuthorizationExchange();
OAuth2AuthorizationResponse authorizationResponse = authorizationExchange.getAuthorizationResponse();
body.with(OAuth2ParameterNames.CODE, authorizationResponse.getCode());
String redirectUri = authorizationExchange.getAuthorizationRequest().getRedirectUri();
if (redirectUri != null) {
body.with(OAuth2ParameterNames.REDIRECT_URI, redirectUri);
}
String codeVerifier = authorizationExchange.getAuthorizationRequest()
.getAttribute(PkceParameterNames.CODE_VERIFIER);
if (codeVerifier != null) {
body.with(PkceParameterNames.CODE_VERIFIER, codeVerifier);
}
return body;
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,9 +16,6 @@
package org.springframework.security.oauth2.client.endpoint;
import java.util.Set;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
/**
@ -44,14 +41,4 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenRespon
public class WebClientReactiveClientCredentialsTokenResponseClient
extends AbstractWebClientReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> {
@Override
ClientRegistration clientRegistration(OAuth2ClientCredentialsGrantRequest grantRequest) {
return grantRequest.getClientRegistration();
}
@Override
Set<String> scopes(OAuth2ClientCredentialsGrantRequest grantRequest) {
return grantRequest.getClientRegistration().getScopes();
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,13 +16,8 @@
package org.springframework.security.oauth2.client.endpoint;
import java.util.Set;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.WebClient;
/**
@ -44,21 +39,4 @@ import org.springframework.web.reactive.function.client.WebClient;
public final class WebClientReactiveJwtBearerTokenResponseClient
extends AbstractWebClientReactiveOAuth2AccessTokenResponseClient<JwtBearerGrantRequest> {
@Override
ClientRegistration clientRegistration(JwtBearerGrantRequest grantRequest) {
return grantRequest.getClientRegistration();
}
@Override
Set<String> scopes(JwtBearerGrantRequest grantRequest) {
return grantRequest.getClientRegistration().getScopes();
}
@Override
BodyInserters.FormInserter<String> populateTokenRequestBody(JwtBearerGrantRequest grantRequest,
BodyInserters.FormInserter<String> body) {
return super.populateTokenRequestBody(grantRequest, body).with(OAuth2ParameterNames.ASSERTION,
grantRequest.getJwt().getTokenValue());
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,13 +16,8 @@
package org.springframework.security.oauth2.client.endpoint;
import java.util.Set;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.WebClient;
/**
@ -51,22 +46,4 @@ import org.springframework.web.reactive.function.client.WebClient;
public final class WebClientReactivePasswordTokenResponseClient
extends AbstractWebClientReactiveOAuth2AccessTokenResponseClient<OAuth2PasswordGrantRequest> {
@Override
ClientRegistration clientRegistration(OAuth2PasswordGrantRequest grantRequest) {
return grantRequest.getClientRegistration();
}
@Override
Set<String> scopes(OAuth2PasswordGrantRequest grantRequest) {
return grantRequest.getClientRegistration().getScopes();
}
@Override
BodyInserters.FormInserter<String> populateTokenRequestBody(OAuth2PasswordGrantRequest grantRequest,
BodyInserters.FormInserter<String> body) {
return super.populateTokenRequestBody(grantRequest, body)
.with(OAuth2ParameterNames.USERNAME, grantRequest.getUsername())
.with(OAuth2ParameterNames.PASSWORD, grantRequest.getPassword());
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,14 +16,11 @@
package org.springframework.security.oauth2.client.endpoint;
import java.util.Set;
import reactor.core.publisher.Mono;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.CollectionUtils;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.WebClient;
/**
@ -44,29 +41,12 @@ public final class WebClientReactiveRefreshTokenTokenResponseClient
extends AbstractWebClientReactiveOAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> {
@Override
ClientRegistration clientRegistration(OAuth2RefreshTokenGrantRequest grantRequest) {
return grantRequest.getClientRegistration();
public Mono<OAuth2AccessTokenResponse> getTokenResponse(OAuth2RefreshTokenGrantRequest grantRequest) {
return super.getTokenResponse(grantRequest)
.map((accessTokenResponse) -> populateTokenResponse(grantRequest, accessTokenResponse));
}
@Override
Set<String> scopes(OAuth2RefreshTokenGrantRequest grantRequest) {
return grantRequest.getScopes();
}
@Override
Set<String> defaultScopes(OAuth2RefreshTokenGrantRequest grantRequest) {
return grantRequest.getAccessToken().getScopes();
}
@Override
BodyInserters.FormInserter<String> populateTokenRequestBody(OAuth2RefreshTokenGrantRequest grantRequest,
BodyInserters.FormInserter<String> body) {
return super.populateTokenRequestBody(grantRequest, body).with(OAuth2ParameterNames.REFRESH_TOKEN,
grantRequest.getRefreshToken().getTokenValue());
}
@Override
OAuth2AccessTokenResponse populateTokenResponse(OAuth2RefreshTokenGrantRequest grantRequest,
private OAuth2AccessTokenResponse populateTokenResponse(OAuth2RefreshTokenGrantRequest grantRequest,
OAuth2AccessTokenResponse accessTokenResponse) {
if (!CollectionUtils.isEmpty(accessTokenResponse.getAccessToken().getScopes())
&& accessTokenResponse.getRefreshToken() != null) {
@ -75,7 +55,7 @@ public final class WebClientReactiveRefreshTokenTokenResponseClient
OAuth2AccessTokenResponse.Builder tokenResponseBuilder = OAuth2AccessTokenResponse
.withResponse(accessTokenResponse);
if (CollectionUtils.isEmpty(accessTokenResponse.getAccessToken().getScopes())) {
tokenResponseBuilder.scopes(defaultScopes(grantRequest));
tokenResponseBuilder.scopes(grantRequest.getAccessToken().getScopes());
}
if (accessTokenResponse.getRefreshToken() == null) {
// Reuse existing refresh token

View File

@ -16,15 +16,8 @@
package org.springframework.security.oauth2.client.endpoint;
import java.util.Set;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2Token;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.WebClient;
/**
@ -46,38 +39,4 @@ import org.springframework.web.reactive.function.client.WebClient;
public final class WebClientReactiveTokenExchangeTokenResponseClient
extends AbstractWebClientReactiveOAuth2AccessTokenResponseClient<TokenExchangeGrantRequest> {
private static final String ACCESS_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:access_token";
private static final String JWT_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:jwt";
@Override
ClientRegistration clientRegistration(TokenExchangeGrantRequest grantRequest) {
return grantRequest.getClientRegistration();
}
@Override
Set<String> scopes(TokenExchangeGrantRequest grantRequest) {
return grantRequest.getClientRegistration().getScopes();
}
@Override
BodyInserters.FormInserter<String> populateTokenRequestBody(TokenExchangeGrantRequest grantRequest,
BodyInserters.FormInserter<String> body) {
super.populateTokenRequestBody(grantRequest, body);
body.with(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE);
OAuth2Token subjectToken = grantRequest.getSubjectToken();
body.with(OAuth2ParameterNames.SUBJECT_TOKEN, subjectToken.getTokenValue());
body.with(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE, tokenType(subjectToken));
OAuth2Token actorToken = grantRequest.getActorToken();
if (actorToken != null) {
body.with(OAuth2ParameterNames.ACTOR_TOKEN, actorToken.getTokenValue());
body.with(OAuth2ParameterNames.ACTOR_TOKEN_TYPE, tokenType(actorToken));
}
return body;
}
private static String tokenType(OAuth2Token token) {
return (token instanceof Jwt) ? JWT_TOKEN_TYPE_VALUE : ACCESS_TOKEN_TYPE_VALUE;
}
}

View File

@ -0,0 +1,47 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import okhttp3.mockwebserver.MockResponse;
import org.springframework.core.io.ClassPathResource;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
/**
* @author Steve Riesenberg
*/
public final class MockResponses {
private MockResponses() {
}
public static MockResponse json(String path) {
try {
String json = new ClassPathResource(path).getContentAsString(StandardCharsets.UTF_8);
return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.setBody(json);
}
catch (IOException ex) {
throw new RuntimeException("Unable to read %s as a classpath resource".formatted(path), ex);
}
}
}

View File

@ -0,0 +1,228 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.endpoint;
import java.util.Map;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.OAuth2Token;
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.TestJwts;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import static org.assertj.core.api.Assertions.assertThat;
/**
* Tests for {@link DefaultOAuth2TokenRequestParametersConverter}.
*
* @author Steve Riesenberg
*/
public class DefaultOAuth2TokenRequestParametersConverterTests {
private static final String ACCESS_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:access_token";
private static final String JWT_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:jwt";
private ClientRegistration.Builder clientRegistration;
@BeforeEach
public void setUp() {
this.clientRegistration = TestClientRegistrations.clientRegistration()
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST)
.clientId("client-1")
.clientSecret("secret")
.scope("read", "write");
}
@Test
public void convertWhenGrantRequestIsAuthorizationCodeThenParametersProvided() {
ClientRegistration clientRegistration = this.clientRegistration
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.build();
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.clientId("client-1")
.state("state")
.authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri())
.redirectUri(clientRegistration.getRedirectUri())
.attributes(Map.of(PkceParameterNames.CODE_VERIFIER, "code-verifier"))
.scopes(clientRegistration.getScopes())
.build();
OAuth2AuthorizationResponse authorizationResponse = OAuth2AuthorizationResponse.success("code")
.state("state")
.redirectUri(clientRegistration.getRedirectUri())
.build();
OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest,
authorizationResponse);
OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration,
authorizationExchange);
// @formatter:off
DefaultOAuth2TokenRequestParametersConverter<OAuth2AuthorizationCodeGrantRequest> parametersConverter =
new DefaultOAuth2TokenRequestParametersConverter<>();
// @formatter:on
MultiValueMap<String, String> parameters = parametersConverter.convert(grantRequest);
assertThat(parameters).hasSize(6);
assertThat(parameters.get(OAuth2ParameterNames.GRANT_TYPE))
.containsExactly(AuthorizationGrantType.AUTHORIZATION_CODE.getValue());
assertThat(parameters.get(OAuth2ParameterNames.CLIENT_ID)).containsExactly(clientRegistration.getClientId());
assertThat(parameters.get(OAuth2ParameterNames.CLIENT_SECRET))
.containsExactly(clientRegistration.getClientSecret());
assertThat(parameters.get(OAuth2ParameterNames.CODE)).containsExactly(authorizationResponse.getCode());
assertThat(parameters.get(OAuth2ParameterNames.REDIRECT_URI))
.containsExactly(clientRegistration.getRedirectUri());
assertThat(parameters.get(PkceParameterNames.CODE_VERIFIER))
.containsExactly(authorizationRequest.<String>getAttribute(PkceParameterNames.CODE_VERIFIER));
}
@Test
public void convertWhenGrantRequestIsClientCredentialsThenParametersProvided() {
ClientRegistration clientRegistration = this.clientRegistration
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
.build();
OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
// @formatter:off
DefaultOAuth2TokenRequestParametersConverter<OAuth2ClientCredentialsGrantRequest> parametersConverter =
new DefaultOAuth2TokenRequestParametersConverter<>();
// @formatter:on
MultiValueMap<String, String> parameters = parametersConverter.convert(grantRequest);
assertThat(parameters).hasSize(4);
assertThat(parameters.get(OAuth2ParameterNames.GRANT_TYPE))
.containsExactly(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue());
assertThat(parameters.get(OAuth2ParameterNames.CLIENT_ID)).containsExactly(clientRegistration.getClientId());
assertThat(parameters.get(OAuth2ParameterNames.CLIENT_SECRET))
.containsExactly(clientRegistration.getClientSecret());
assertThat(parameters.get(OAuth2ParameterNames.SCOPE))
.containsExactly(StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " "));
}
@Test
public void convertWhenGrantRequestIsRefreshTokenThenParametersProvided() {
ClientRegistration clientRegistration = this.clientRegistration
.authorizationGrantType(AuthorizationGrantType.REFRESH_TOKEN)
.build();
OAuth2AccessToken accessToken = TestOAuth2AccessTokens.scopes("read", "write");
OAuth2RefreshToken refreshToken = TestOAuth2RefreshTokens.refreshToken();
OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration,
accessToken, refreshToken, clientRegistration.getScopes());
// @formatter:off
DefaultOAuth2TokenRequestParametersConverter<OAuth2RefreshTokenGrantRequest> parametersConverter =
new DefaultOAuth2TokenRequestParametersConverter<>();
// @formatter:on
MultiValueMap<String, String> parameters = parametersConverter.convert(grantRequest);
assertThat(parameters).hasSize(5);
assertThat(parameters.get(OAuth2ParameterNames.GRANT_TYPE))
.containsExactly(AuthorizationGrantType.REFRESH_TOKEN.getValue());
assertThat(parameters.get(OAuth2ParameterNames.CLIENT_ID)).containsExactly(clientRegistration.getClientId());
assertThat(parameters.get(OAuth2ParameterNames.CLIENT_SECRET))
.containsExactly(clientRegistration.getClientSecret());
assertThat(parameters.get(OAuth2ParameterNames.REFRESH_TOKEN)).containsExactly(refreshToken.getTokenValue());
assertThat(parameters.get(OAuth2ParameterNames.SCOPE))
.containsExactly(StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " "));
}
@Test
public void convertWhenGrantRequestIsPasswordThenParametersProvided() {
ClientRegistration clientRegistration = this.clientRegistration
.authorizationGrantType(AuthorizationGrantType.PASSWORD)
.build();
OAuth2PasswordGrantRequest grantRequest = new OAuth2PasswordGrantRequest(clientRegistration, "user",
"password");
// @formatter:off
DefaultOAuth2TokenRequestParametersConverter<OAuth2PasswordGrantRequest> parametersConverter =
new DefaultOAuth2TokenRequestParametersConverter<>();
// @formatter:on
MultiValueMap<String, String> parameters = parametersConverter.convert(grantRequest);
assertThat(parameters).hasSize(6);
assertThat(parameters.get(OAuth2ParameterNames.GRANT_TYPE))
.containsExactly(AuthorizationGrantType.PASSWORD.getValue());
assertThat(parameters.get(OAuth2ParameterNames.CLIENT_ID)).containsExactly(clientRegistration.getClientId());
assertThat(parameters.get(OAuth2ParameterNames.CLIENT_SECRET))
.containsExactly(clientRegistration.getClientSecret());
assertThat(parameters.get(OAuth2ParameterNames.USERNAME)).containsExactly(grantRequest.getUsername());
assertThat(parameters.get(OAuth2ParameterNames.PASSWORD)).containsExactly(grantRequest.getPassword());
assertThat(parameters.get(OAuth2ParameterNames.SCOPE))
.containsExactly(StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " "));
}
@Test
public void convertWhenGrantRequestIsJwtBearerThenParametersProvided() {
ClientRegistration clientRegistration = this.clientRegistration
.authorizationGrantType(AuthorizationGrantType.JWT_BEARER)
.build();
Jwt jwt = TestJwts.jwt().build();
JwtBearerGrantRequest grantRequest = new JwtBearerGrantRequest(clientRegistration, jwt);
// @formatter:off
DefaultOAuth2TokenRequestParametersConverter<JwtBearerGrantRequest> parametersConverter =
new DefaultOAuth2TokenRequestParametersConverter<>();
// @formatter:on
MultiValueMap<String, String> parameters = parametersConverter.convert(grantRequest);
assertThat(parameters).hasSize(5);
assertThat(parameters.get(OAuth2ParameterNames.GRANT_TYPE))
.containsExactly(AuthorizationGrantType.JWT_BEARER.getValue());
assertThat(parameters.get(OAuth2ParameterNames.CLIENT_ID)).containsExactly(clientRegistration.getClientId());
assertThat(parameters.get(OAuth2ParameterNames.CLIENT_SECRET))
.containsExactly(clientRegistration.getClientSecret());
assertThat(parameters.get(OAuth2ParameterNames.ASSERTION)).containsExactly(jwt.getTokenValue());
assertThat(parameters.get(OAuth2ParameterNames.SCOPE))
.containsExactly(StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " "));
}
@Test
public void convertWhenGrantRequestIsTokenExchangeThenParametersProvided() {
ClientRegistration clientRegistration = this.clientRegistration
.authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE)
.build();
OAuth2Token subjectToken = TestOAuth2AccessTokens.scopes("read", "write");
OAuth2Token actorToken = TestJwts.jwt().build();
TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, subjectToken,
actorToken);
// @formatter:off
DefaultOAuth2TokenRequestParametersConverter<TokenExchangeGrantRequest> parametersConverter =
new DefaultOAuth2TokenRequestParametersConverter<>();
// @formatter:on
MultiValueMap<String, String> parameters = parametersConverter.convert(grantRequest);
assertThat(parameters).hasSize(9);
assertThat(parameters.get(OAuth2ParameterNames.GRANT_TYPE))
.containsExactly(AuthorizationGrantType.TOKEN_EXCHANGE.getValue());
assertThat(parameters.get(OAuth2ParameterNames.CLIENT_ID)).containsExactly(clientRegistration.getClientId());
assertThat(parameters.get(OAuth2ParameterNames.CLIENT_SECRET))
.containsExactly(clientRegistration.getClientSecret());
assertThat(parameters.get(OAuth2ParameterNames.SCOPE))
.containsExactly(StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " "));
assertThat(parameters.get(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE)).containsExactly(ACCESS_TOKEN_TYPE_VALUE);
assertThat(parameters.get(OAuth2ParameterNames.SUBJECT_TOKEN)).containsExactly(subjectToken.getTokenValue());
assertThat(parameters.get(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE)).containsExactly(ACCESS_TOKEN_TYPE_VALUE);
assertThat(parameters.get(OAuth2ParameterNames.ACTOR_TOKEN)).containsExactly(actorToken.getTokenValue());
assertThat(parameters.get(OAuth2ParameterNames.ACTOR_TOKEN_TYPE)).containsExactly(JWT_TOKEN_TYPE_VALUE);
}
}

View File

@ -21,6 +21,7 @@ import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.Collections;
import java.util.function.Consumer;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
@ -34,6 +35,7 @@ import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.converter.FormHttpMessageConverter;
import org.springframework.security.oauth2.client.MockResponses;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
@ -54,6 +56,7 @@ import org.springframework.web.client.RestClient;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
@ -80,13 +83,12 @@ public class RestClientAuthorizationCodeTokenResponseClientTests {
this.server = new MockWebServer();
this.server.start();
String tokenUri = this.server.url("/oauth2/token").toString();
// @formatter:off
this.clientRegistration = TestClientRegistrations.clientRegistration()
.clientId("client-1")
.clientSecret("secret")
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
.tokenUri(tokenUri)
.scope("read", "write");
.clientId("client-1")
.clientSecret("secret")
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
.tokenUri(tokenUri)
.scope("read", "write");
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.clientId("client-1")
@ -99,7 +101,6 @@ public class RestClientAuthorizationCodeTokenResponseClientTests {
.state("state")
.redirectUri(clientRegistration.getRedirectUri())
.build();
// @formatter:on
this.authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse);
}
@ -164,15 +165,7 @@ public class RestClientAuthorizationCodeTokenResponseClientTests {
@Test
public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read write\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response-read-write.json"));
Instant expiresAtBefore = Instant.now().plusSeconds(3600);
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration,
@ -201,14 +194,7 @@ public class RestClientAuthorizationCodeTokenResponseClientTests {
@Test
public void getTokenResponseWhenAuthenticationClientSecretBasicThenAuthorizationHeaderIsSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration,
this.authorizationExchange);
@ -219,14 +205,7 @@ public class RestClientAuthorizationCodeTokenResponseClientTests {
@Test
public void getTokenResponseWhenAuthenticationClientSecretPostThenFormParametersAreSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST)
.build();
@ -235,19 +214,17 @@ public class RestClientAuthorizationCodeTokenResponseClientTests {
this.tokenResponseClient.getTokenResponse(grantRequest);
RecordedRequest recordedRequest = this.server.takeRequest();
String formParameters = recordedRequest.getBody().readUtf8();
assertThat(formParameters).contains("client_id=client-1", "client_secret=secret");
// @formatter:off
assertThat(formParameters).contains(
param(OAuth2ParameterNames.CLIENT_ID, "client-1"),
param(OAuth2ParameterNames.CLIENT_SECRET, "secret")
);
// @formatter:on
}
@Test
public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"not-bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("invalid-token-type-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration,
this.authorizationExchange);
@ -262,15 +239,7 @@ public class RestClientAuthorizationCodeTokenResponseClientTests {
@Test
public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response-read.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration,
this.authorizationExchange);
@ -281,14 +250,7 @@ public class RestClientAuthorizationCodeTokenResponseClientTests {
@Test
public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenAccessTokenHasNoScope() {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration,
this.authorizationExchange);
@ -313,8 +275,7 @@ public class RestClientAuthorizationCodeTokenResponseClientTests {
@Test
public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() {
String accessTokenErrorResponse = "{\"error\": \"server_error\", \"error_description\": \"A server error occurred\"}";
this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(500));
this.server.enqueue(MockResponses.json("server-error-response.json").setResponseCode(500));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2AuthorizationCodeGrantRequest request = new OAuth2AuthorizationCodeGrantRequest(clientRegistration,
this.authorizationExchange);
@ -328,8 +289,7 @@ public class RestClientAuthorizationCodeTokenResponseClientTests {
@Test
public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() {
String accessTokenErrorResponse = "{\"error\": \"invalid_grant\", \"error_description\": \"Invalid grant\"}";
this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400));
this.server.enqueue(MockResponses.json("invalid-grant-response.json").setResponseCode(400));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2AuthorizationCodeGrantRequest request = new OAuth2AuthorizationCodeGrantRequest(clientRegistration,
this.authorizationExchange);
@ -371,18 +331,11 @@ public class RestClientAuthorizationCodeTokenResponseClientTests {
@Test
public void getTokenResponseWhenHeadersConverterAddedThenCalled() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration,
this.authorizationExchange);
Converter<OAuth2AuthorizationCodeGrantRequest, HttpHeaders> headersConverter = mock(Converter.class);
Converter<OAuth2AuthorizationCodeGrantRequest, HttpHeaders> headersConverter = mock();
HttpHeaders headers = new HttpHeaders();
headers.put("custom-header-name", Collections.singletonList("custom-header-value"));
given(headersConverter.convert(grantRequest)).willReturn(headers);
@ -396,18 +349,11 @@ public class RestClientAuthorizationCodeTokenResponseClientTests {
@Test
public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration,
this.authorizationExchange);
Converter<OAuth2AuthorizationCodeGrantRequest, HttpHeaders> headersConverter = mock(Converter.class);
Converter<OAuth2AuthorizationCodeGrantRequest, HttpHeaders> headersConverter = mock();
HttpHeaders headers = new HttpHeaders();
headers.put("custom-header-name", Collections.singletonList("custom-header-value"));
given(headersConverter.convert(grantRequest)).willReturn(headers);
@ -421,19 +367,11 @@ public class RestClientAuthorizationCodeTokenResponseClientTests {
@Test
public void getTokenResponseWhenParametersConverterSetThenCalled() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration,
this.authorizationExchange);
Converter<OAuth2AuthorizationCodeGrantRequest, MultiValueMap<String, String>> parametersConverter = mock(
Converter.class);
Converter<OAuth2AuthorizationCodeGrantRequest, MultiValueMap<String, String>> parametersConverter = mock();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add("custom-parameter-name", "custom-parameter-value");
given(parametersConverter.convert(grantRequest)).willReturn(parameters);
@ -442,20 +380,13 @@ public class RestClientAuthorizationCodeTokenResponseClientTests {
verify(parametersConverter).convert(grantRequest);
RecordedRequest recordedRequest = this.server.takeRequest();
String formParameters = recordedRequest.getBody().readUtf8();
assertThat(formParameters).contains("custom-parameter-name=custom-parameter-value");
assertThat(formParameters).contains(param("custom-parameter-name", "custom-parameter-value"));
}
@Test
public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultParameters() throws Exception {
this.clientRegistration.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST);
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration,
this.authorizationExchange);
@ -463,7 +394,6 @@ public class RestClientAuthorizationCodeTokenResponseClientTests {
parameters.set(OAuth2ParameterNames.GRANT_TYPE, "custom");
parameters.set(OAuth2ParameterNames.CODE, "custom-code");
parameters.set(OAuth2ParameterNames.REDIRECT_URI, "custom-uri");
// The client_id parameter is omitted for testing purposes
this.tokenResponseClient.setParametersConverter((authorizationGrantRequest) -> parameters);
this.tokenResponseClient.getTokenResponse(grantRequest);
RecordedRequest recordedRequest = this.server.takeRequest();
@ -471,27 +401,20 @@ public class RestClientAuthorizationCodeTokenResponseClientTests {
// @formatter:off
assertThat(formParameters).contains(
param(OAuth2ParameterNames.GRANT_TYPE, "custom"),
param(OAuth2ParameterNames.CLIENT_ID, "client-1"),
param(OAuth2ParameterNames.CODE, "custom-code"),
param(OAuth2ParameterNames.REDIRECT_URI, "custom-uri"));
param(OAuth2ParameterNames.REDIRECT_URI, "custom-uri")
);
// @formatter:on
assertThat(formParameters).doesNotContain(OAuth2ParameterNames.CLIENT_ID);
}
@Test
public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration,
this.authorizationExchange);
Converter<OAuth2AuthorizationCodeGrantRequest, MultiValueMap<String, String>> parametersConverter = mock(
Converter.class);
Converter<OAuth2AuthorizationCodeGrantRequest, MultiValueMap<String, String>> parametersConverter = mock();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add("custom-parameter-name", "custom-parameter-value");
given(parametersConverter.convert(grantRequest)).willReturn(parameters);
@ -510,15 +433,25 @@ public class RestClientAuthorizationCodeTokenResponseClientTests {
}
@Test
public void getTokenResponseWhenRestClientSetThenCalled() {
public void getTokenResponseWhenParametersCustomizerSetThenCalled() throws Exception {
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration,
this.authorizationExchange);
Consumer<MultiValueMap<String, String>> parametersCustomizer = mock();
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
DefaultOAuth2TokenRequestParametersConverter<OAuth2AuthorizationCodeGrantRequest> parametersConverter =
new DefaultOAuth2TokenRequestParametersConverter<>();
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
parametersConverter.setParametersCustomizer(parametersCustomizer);
this.tokenResponseClient.setParametersConverter(parametersConverter);
this.tokenResponseClient.getTokenResponse(grantRequest);
verify(parametersCustomizer).accept(any());
}
@Test
public void getTokenResponseWhenRestClientSetThenCalled() {
this.server.enqueue(MockResponses.json("access-token-response.json"));
RestClient restClient = RestClient.builder().messageConverters((messageConverters) -> {
messageConverters.add(0, new FormHttpMessageConverter());
messageConverters.add(1, new OAuth2AccessTokenResponseHttpMessageConverter());
@ -532,10 +465,6 @@ public class RestClientAuthorizationCodeTokenResponseClientTests {
verify(customClient).post();
}
private static MockResponse jsonResponse(String json) {
return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json);
}
private static String param(String parameterName, String parameterValue) {
return "%s=%s".formatted(parameterName, URLEncoder.encode(parameterValue, StandardCharsets.UTF_8));
}

View File

@ -22,6 +22,7 @@ import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.Collections;
import java.util.Set;
import java.util.function.Consumer;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
@ -35,6 +36,7 @@ import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.converter.FormHttpMessageConverter;
import org.springframework.security.oauth2.client.MockResponses;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
@ -53,6 +55,7 @@ import org.springframework.web.client.RestClient;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
@ -77,14 +80,12 @@ public class RestClientClientCredentialsTokenResponseClientTests {
this.server = new MockWebServer();
this.server.start();
String tokenUri = this.server.url("/oauth2/token").toString();
// @formatter:off
this.clientRegistration = TestClientRegistrations.clientCredentials()
.clientId("client-1")
.clientSecret("secret")
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
.tokenUri(tokenUri)
.scope("read", "write");
// @formatter:on
.clientId("client-1")
.clientSecret("secret")
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
.tokenUri(tokenUri)
.scope("read", "write");
}
@AfterEach
@ -148,15 +149,7 @@ public class RestClientClientCredentialsTokenResponseClientTests {
@Test
public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read write\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response-read-write.json"));
Instant expiresAtBefore = Instant.now().plusSeconds(3600);
ClientRegistration clientRegistration = this.clientRegistration.build();
Set<String> scopes = clientRegistration.getScopes();
@ -185,14 +178,7 @@ public class RestClientClientCredentialsTokenResponseClientTests {
@Test
public void getTokenResponseWhenAuthenticationClientSecretBasicThenAuthorizationHeaderIsSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
this.tokenResponseClient.getTokenResponse(grantRequest);
@ -202,14 +188,7 @@ public class RestClientClientCredentialsTokenResponseClientTests {
@Test
public void getTokenResponseWhenAuthenticationClientSecretPostThenFormParametersAreSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST)
.build();
@ -217,19 +196,17 @@ public class RestClientClientCredentialsTokenResponseClientTests {
this.tokenResponseClient.getTokenResponse(grantRequest);
RecordedRequest recordedRequest = this.server.takeRequest();
String formParameters = recordedRequest.getBody().readUtf8();
assertThat(formParameters).contains("client_id=client-1", "client_secret=secret");
// @formatter:off
assertThat(formParameters).contains(
param(OAuth2ParameterNames.CLIENT_ID, "client-1"),
param(OAuth2ParameterNames.CLIENT_SECRET, "secret")
);
// @formatter:on
}
@Test
public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"not-bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("invalid-token-type-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
// @formatter:off
@ -243,15 +220,7 @@ public class RestClientClientCredentialsTokenResponseClientTests {
@Test
public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response-read.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest);
@ -261,14 +230,7 @@ public class RestClientClientCredentialsTokenResponseClientTests {
@Test
public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenAccessTokenHasNoScope() {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest);
@ -278,14 +240,7 @@ public class RestClientClientCredentialsTokenResponseClientTests {
@Test
public void getTokenResponseWhenRequestDoesNotIncludeScopeThenAccessTokenHasNoScope() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
// @formatter:off
ClientRegistration clientRegistration = ClientRegistration.withRegistrationId("no-scope")
.clientId("client-1")
@ -328,8 +283,7 @@ public class RestClientClientCredentialsTokenResponseClientTests {
@Test
public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() {
String accessTokenErrorResponse = "{\"error\": \"server_error\", \"error_description\": \"A server error occurred\"}";
this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(500));
this.server.enqueue(MockResponses.json("server-error-response.json").setResponseCode(500));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
// @formatter:off
@ -342,8 +296,7 @@ public class RestClientClientCredentialsTokenResponseClientTests {
@Test
public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() {
String accessTokenErrorResponse = "{\"error\": \"invalid_grant\", \"error_description\": \"Invalid grant\"}";
this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400));
this.server.enqueue(MockResponses.json("invalid-grant-response.json").setResponseCode(400));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
// @formatter:off
@ -382,17 +335,10 @@ public class RestClientClientCredentialsTokenResponseClientTests {
@Test
public void getTokenResponseWhenHeadersConverterAddedThenCalled() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
Converter<OAuth2ClientCredentialsGrantRequest, HttpHeaders> headersConverter = mock(Converter.class);
Converter<OAuth2ClientCredentialsGrantRequest, HttpHeaders> headersConverter = mock();
HttpHeaders headers = new HttpHeaders();
headers.put("custom-header-name", Collections.singletonList("custom-header-value"));
given(headersConverter.convert(grantRequest)).willReturn(headers);
@ -406,17 +352,10 @@ public class RestClientClientCredentialsTokenResponseClientTests {
@Test
public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
Converter<OAuth2ClientCredentialsGrantRequest, HttpHeaders> headersConverter = mock(Converter.class);
Converter<OAuth2ClientCredentialsGrantRequest, HttpHeaders> headersConverter = mock();
HttpHeaders headers = new HttpHeaders();
headers.put("custom-header-name", Collections.singletonList("custom-header-value"));
given(headersConverter.convert(grantRequest)).willReturn(headers);
@ -430,18 +369,10 @@ public class RestClientClientCredentialsTokenResponseClientTests {
@Test
public void getTokenResponseWhenParametersConverterSetThenCalled() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
Converter<OAuth2ClientCredentialsGrantRequest, MultiValueMap<String, String>> parametersConverter = mock(
Converter.class);
Converter<OAuth2ClientCredentialsGrantRequest, MultiValueMap<String, String>> parametersConverter = mock();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add("custom-parameter-name", "custom-parameter-value");
given(parametersConverter.convert(grantRequest)).willReturn(parameters);
@ -450,28 +381,19 @@ public class RestClientClientCredentialsTokenResponseClientTests {
verify(parametersConverter).convert(grantRequest);
RecordedRequest recordedRequest = this.server.takeRequest();
String formParameters = recordedRequest.getBody().readUtf8();
assertThat(formParameters).contains("custom-parameter-name=custom-parameter-value");
assertThat(formParameters).contains(param("custom-parameter-name", "custom-parameter-value"));
}
@Test
public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultParameters() throws Exception {
this.clientRegistration.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST);
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
Converter<OAuth2ClientCredentialsGrantRequest, MultiValueMap<String, String>> parametersConverter = mock(
Converter.class);
Converter<OAuth2ClientCredentialsGrantRequest, MultiValueMap<String, String>> parametersConverter = mock();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.set(OAuth2ParameterNames.GRANT_TYPE, "custom");
parameters.set(OAuth2ParameterNames.SCOPE, "one two");
// The client_id parameter is omitted for testing purposes
given(parametersConverter.convert(grantRequest)).willReturn(parameters);
this.tokenResponseClient.setParametersConverter((authorizationGrantRequest) -> parameters);
this.tokenResponseClient.getTokenResponse(grantRequest);
@ -480,26 +402,19 @@ public class RestClientClientCredentialsTokenResponseClientTests {
// @formatter:off
assertThat(formParameters).contains(
param(OAuth2ParameterNames.GRANT_TYPE, "custom"),
param(OAuth2ParameterNames.SCOPE, "one two"));
param(OAuth2ParameterNames.CLIENT_ID, "client-1"),
param(OAuth2ParameterNames.SCOPE, "one two")
);
// @formatter:on
assertThat(formParameters).doesNotContain(OAuth2ParameterNames.CLIENT_ID);
}
@Test
public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
Set<String> scopes = clientRegistration.getScopes();
OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
Converter<OAuth2ClientCredentialsGrantRequest, MultiValueMap<String, String>> parametersConverter = mock(
Converter.class);
Converter<OAuth2ClientCredentialsGrantRequest, MultiValueMap<String, String>> parametersConverter = mock();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add("custom-parameter-name", "custom-parameter-value");
given(parametersConverter.convert(grantRequest)).willReturn(parameters);
@ -518,15 +433,24 @@ public class RestClientClientCredentialsTokenResponseClientTests {
}
@Test
public void getTokenResponseWhenRestClientSetThenCalled() {
public void getTokenResponseWhenParametersCustomizerSetThenCalled() throws Exception {
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
Consumer<MultiValueMap<String, String>> parametersCustomizer = mock();
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
DefaultOAuth2TokenRequestParametersConverter<OAuth2ClientCredentialsGrantRequest> parametersConverter =
new DefaultOAuth2TokenRequestParametersConverter<>();
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
parametersConverter.setParametersCustomizer(parametersCustomizer);
this.tokenResponseClient.setParametersConverter(parametersConverter);
this.tokenResponseClient.getTokenResponse(grantRequest);
verify(parametersCustomizer).accept(any());
}
@Test
public void getTokenResponseWhenRestClientSetThenCalled() {
this.server.enqueue(MockResponses.json("access-token-response.json"));
RestClient restClient = RestClient.builder().messageConverters((messageConverters) -> {
messageConverters.add(0, new FormHttpMessageConverter());
messageConverters.add(1, new OAuth2AccessTokenResponseHttpMessageConverter());
@ -539,10 +463,6 @@ public class RestClientClientCredentialsTokenResponseClientTests {
verify(customClient).post();
}
private static MockResponse jsonResponse(String json) {
return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json);
}
private static String param(String parameterName, String parameterValue) {
return "%s=%s".formatted(parameterName, URLEncoder.encode(parameterValue, StandardCharsets.UTF_8));
}

View File

@ -22,6 +22,7 @@ import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.Collections;
import java.util.Set;
import java.util.function.Consumer;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
@ -34,6 +35,7 @@ import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.security.oauth2.client.MockResponses;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
@ -53,6 +55,7 @@ import org.springframework.web.client.RestClient;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
@ -78,14 +81,12 @@ public class RestClientJwtBearerTokenResponseClientTests {
this.server = new MockWebServer();
this.server.start();
String tokenUri = this.server.url("/oauth2/token").toString();
// @formatter:off
this.clientRegistration = TestClientRegistrations.clientCredentials()
.clientId("client-1")
.clientSecret("secret")
.authorizationGrantType(AuthorizationGrantType.JWT_BEARER)
.tokenUri(tokenUri)
.scope("read", "write");
// @formatter:on
.clientId("client-1")
.clientSecret("secret")
.authorizationGrantType(AuthorizationGrantType.JWT_BEARER)
.tokenUri(tokenUri)
.scope("read", "write");
this.jwtAssertion = TestJwts.jwt().build();
}
@ -150,15 +151,7 @@ public class RestClientJwtBearerTokenResponseClientTests {
@Test
public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read write\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response-read-write.json"));
Instant expiresAtBefore = Instant.now().plusSeconds(3600);
ClientRegistration clientRegistration = this.clientRegistration.build();
Set<String> scopes = clientRegistration.getScopes();
@ -188,14 +181,7 @@ public class RestClientJwtBearerTokenResponseClientTests {
@Test
public void getTokenResponseWhenAuthenticationClientSecretBasicThenAuthorizationHeaderIsSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
JwtBearerGrantRequest grantRequest = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion);
this.tokenResponseClient.getTokenResponse(grantRequest);
@ -205,14 +191,7 @@ public class RestClientJwtBearerTokenResponseClientTests {
@Test
public void getTokenResponseWhenAuthenticationClientSecretPostThenFormParametersAreSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST)
.build();
@ -220,19 +199,17 @@ public class RestClientJwtBearerTokenResponseClientTests {
this.tokenResponseClient.getTokenResponse(grantRequest);
RecordedRequest recordedRequest = this.server.takeRequest();
String formParameters = recordedRequest.getBody().readUtf8();
assertThat(formParameters).contains("client_id=client-1", "client_secret=secret");
// @formatter:off
assertThat(formParameters).contains(
param(OAuth2ParameterNames.CLIENT_ID, "client-1"),
param(OAuth2ParameterNames.CLIENT_SECRET, "secret")
);
// @formatter:on
}
@Test
public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"not-bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("invalid-token-type-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
JwtBearerGrantRequest grantRequest = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion);
// @formatter:off
@ -246,15 +223,7 @@ public class RestClientJwtBearerTokenResponseClientTests {
@Test
public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response-read.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
JwtBearerGrantRequest grantRequest = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion);
OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest);
@ -264,14 +233,7 @@ public class RestClientJwtBearerTokenResponseClientTests {
@Test
public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenAccessTokenHasNoScope() {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
JwtBearerGrantRequest grantRequest = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion);
OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest);
@ -294,8 +256,7 @@ public class RestClientJwtBearerTokenResponseClientTests {
@Test
public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() {
String accessTokenErrorResponse = "{\"error\": \"server_error\", \"error_description\": \"A server error occurred\"}";
this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(500));
this.server.enqueue(MockResponses.json("server-error-response.json").setResponseCode(500));
ClientRegistration clientRegistration = this.clientRegistration.build();
JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion);
// @formatter:off
@ -308,8 +269,7 @@ public class RestClientJwtBearerTokenResponseClientTests {
@Test
public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() {
String accessTokenErrorResponse = "{\"error\": \"invalid_grant\", \"error_description\": \"Invalid grant\"}";
this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400));
this.server.enqueue(MockResponses.json("invalid-grant-response.json").setResponseCode(400));
ClientRegistration clientRegistration = this.clientRegistration.build();
JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion);
// @formatter:off
@ -348,17 +308,10 @@ public class RestClientJwtBearerTokenResponseClientTests {
@Test
public void getTokenResponseWhenHeadersConverterAddedThenCalled() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
JwtBearerGrantRequest grantRequest = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion);
Converter<JwtBearerGrantRequest, HttpHeaders> headersConverter = mock(Converter.class);
Converter<JwtBearerGrantRequest, HttpHeaders> headersConverter = mock();
HttpHeaders headers = new HttpHeaders();
headers.put("custom-header-name", Collections.singletonList("custom-header-value"));
given(headersConverter.convert(grantRequest)).willReturn(headers);
@ -372,17 +325,10 @@ public class RestClientJwtBearerTokenResponseClientTests {
@Test
public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
JwtBearerGrantRequest grantRequest = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion);
Converter<JwtBearerGrantRequest, HttpHeaders> headersConverter = mock(Converter.class);
Converter<JwtBearerGrantRequest, HttpHeaders> headersConverter = mock();
HttpHeaders headers = new HttpHeaders();
headers.put("custom-header-name", Collections.singletonList("custom-header-value"));
given(headersConverter.convert(grantRequest)).willReturn(headers);
@ -397,14 +343,7 @@ public class RestClientJwtBearerTokenResponseClientTests {
@Test
public void getTokenResponseWhenParametersConverterSetThenCalled() throws Exception {
this.clientRegistration.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST);
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
JwtBearerGrantRequest grantRequest = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion);
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
@ -419,26 +358,20 @@ public class RestClientJwtBearerTokenResponseClientTests {
// @formatter:off
assertThat(formParameters).contains(
param(OAuth2ParameterNames.GRANT_TYPE, "custom"),
param(OAuth2ParameterNames.ASSERTION, "custom-assertion"),
param(OAuth2ParameterNames.SCOPE, "one two"));
param(OAuth2ParameterNames.CLIENT_ID, "client-1"),
param(OAuth2ParameterNames.SCOPE, "one two"),
param(OAuth2ParameterNames.ASSERTION, "custom-assertion")
);
// @formatter:on
assertThat(formParameters).doesNotContain(OAuth2ParameterNames.CLIENT_ID);
}
@Test
public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultParameters() throws Exception {
this.clientRegistration.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST);
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
JwtBearerGrantRequest grantRequest = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion);
Converter<JwtBearerGrantRequest, MultiValueMap<String, String>> parametersConverter = mock(Converter.class);
Converter<JwtBearerGrantRequest, MultiValueMap<String, String>> parametersConverter = mock();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add("custom-parameter-name", "custom-parameter-value");
given(parametersConverter.convert(grantRequest)).willReturn(parameters);
@ -447,23 +380,16 @@ public class RestClientJwtBearerTokenResponseClientTests {
verify(parametersConverter).convert(grantRequest);
RecordedRequest recordedRequest = this.server.takeRequest();
String formParameters = recordedRequest.getBody().readUtf8();
assertThat(formParameters).contains("custom-parameter-name=custom-parameter-value");
assertThat(formParameters).contains(param("custom-parameter-name", "custom-parameter-value"));
}
@Test
public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
Set<String> scopes = clientRegistration.getScopes();
JwtBearerGrantRequest grantRequest = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion);
Converter<JwtBearerGrantRequest, MultiValueMap<String, String>> parametersConverter = mock(Converter.class);
Converter<JwtBearerGrantRequest, MultiValueMap<String, String>> parametersConverter = mock();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add("custom-parameter-name", "custom-parameter-value");
given(parametersConverter.convert(grantRequest)).willReturn(parameters);
@ -483,16 +409,25 @@ public class RestClientJwtBearerTokenResponseClientTests {
}
@Test
public void getTokenResponseWhenRestClientSetThenCalled() {
public void getTokenResponseWhenParametersCustomizerSetThenCalled() throws Exception {
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
JwtBearerGrantRequest grantRequest = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion);
Consumer<MultiValueMap<String, String>> parametersCustomizer = mock();
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
DefaultOAuth2TokenRequestParametersConverter<JwtBearerGrantRequest> parametersConverter =
new DefaultOAuth2TokenRequestParametersConverter<>();
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
RestClient customClient = mock(RestClient.class);
parametersConverter.setParametersCustomizer(parametersCustomizer);
this.tokenResponseClient.setParametersConverter(parametersConverter);
this.tokenResponseClient.getTokenResponse(grantRequest);
verify(parametersCustomizer).accept(any());
}
@Test
public void getTokenResponseWhenRestClientSetThenCalled() {
this.server.enqueue(MockResponses.json("access-token-response.json"));
RestClient customClient = mock();
given(customClient.post()).willReturn(RestClient.builder().build().post());
this.tokenResponseClient.setRestClient(customClient);
ClientRegistration clientRegistration = this.clientRegistration.build();
@ -501,10 +436,6 @@ public class RestClientJwtBearerTokenResponseClientTests {
verify(customClient).post();
}
private static MockResponse jsonResponse(String json) {
return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json);
}
private static String param(String parameterName, String parameterValue) {
return "%s=%s".formatted(parameterName, URLEncoder.encode(parameterValue, StandardCharsets.UTF_8));
}

View File

@ -22,6 +22,7 @@ import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.Collections;
import java.util.Set;
import java.util.function.Consumer;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
@ -35,6 +36,7 @@ import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.converter.FormHttpMessageConverter;
import org.springframework.security.oauth2.client.MockResponses;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
@ -56,6 +58,7 @@ import org.springframework.web.client.RestClient;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
@ -84,14 +87,12 @@ public class RestClientRefreshTokenTokenResponseClientTests {
this.server = new MockWebServer();
this.server.start();
String tokenUri = this.server.url("/oauth2/token").toString();
// @formatter:off
this.clientRegistration = TestClientRegistrations.clientCredentials()
.clientId("client-1")
.clientSecret("secret")
.authorizationGrantType(AuthorizationGrantType.REFRESH_TOKEN)
.tokenUri(tokenUri)
.scope("read", "write");
// @formatter:on
.clientId("client-1")
.clientSecret("secret")
.authorizationGrantType(AuthorizationGrantType.REFRESH_TOKEN)
.tokenUri(tokenUri)
.scope("read", "write");
this.accessToken = TestOAuth2AccessTokens.scopes("read", "write");
this.refreshToken = TestOAuth2RefreshTokens.refreshToken();
}
@ -157,15 +158,7 @@ public class RestClientRefreshTokenTokenResponseClientTests {
@Test
public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read write\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response-read-write.json"));
Instant expiresAtBefore = Instant.now().plusSeconds(3600);
ClientRegistration clientRegistration = this.clientRegistration.build();
Set<String> scopes = clientRegistration.getScopes();
@ -196,14 +189,7 @@ public class RestClientRefreshTokenTokenResponseClientTests {
@Test
public void getTokenResponseWhenAuthenticationClientSecretBasicThenAuthorizationHeaderIsSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration,
this.accessToken, this.refreshToken);
@ -214,14 +200,7 @@ public class RestClientRefreshTokenTokenResponseClientTests {
@Test
public void getTokenResponseWhenAuthenticationClientSecretPostThenFormParametersAreSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST)
.build();
@ -230,19 +209,17 @@ public class RestClientRefreshTokenTokenResponseClientTests {
this.tokenResponseClient.getTokenResponse(grantRequest);
RecordedRequest recordedRequest = this.server.takeRequest();
String formParameters = recordedRequest.getBody().readUtf8();
assertThat(formParameters).contains("client_id=client-1", "client_secret=secret");
// @formatter:off
assertThat(formParameters).contains(
param(OAuth2ParameterNames.CLIENT_ID, "client-1"),
param(OAuth2ParameterNames.CLIENT_SECRET, "secret")
);
// @formatter:on
}
@Test
public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"not-bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("invalid-token-type-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration,
this.accessToken, this.refreshToken);
@ -257,15 +234,7 @@ public class RestClientRefreshTokenTokenResponseClientTests {
@Test
public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response-read.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration,
this.accessToken, this.refreshToken);
@ -276,14 +245,7 @@ public class RestClientRefreshTokenTokenResponseClientTests {
@Test
public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenAccessTokenHasRequestedScope() {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
Set<String> scopes = clientRegistration.getScopes();
OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration,
@ -295,15 +257,7 @@ public class RestClientRefreshTokenTokenResponseClientTests {
@Test
public void getTokenResponseWhenRequestDoesNotIncludeScopeThenAccessTokenHasResponseScope() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response-read.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration,
this.accessToken, this.refreshToken);
@ -341,8 +295,7 @@ public class RestClientRefreshTokenTokenResponseClientTests {
@Test
public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() {
String accessTokenErrorResponse = "{\"error\": \"server_error\", \"error_description\": \"A server error occurred\"}";
this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(500));
this.server.enqueue(MockResponses.json("server-error-response.json").setResponseCode(500));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2RefreshTokenGrantRequest request = new OAuth2RefreshTokenGrantRequest(clientRegistration,
this.accessToken, this.refreshToken);
@ -356,8 +309,7 @@ public class RestClientRefreshTokenTokenResponseClientTests {
@Test
public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() {
String accessTokenErrorResponse = "{\"error\": \"invalid_grant\", \"error_description\": \"Invalid grant\"}";
this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400));
this.server.enqueue(MockResponses.json("invalid-grant-response.json").setResponseCode(400));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2RefreshTokenGrantRequest request = new OAuth2RefreshTokenGrantRequest(clientRegistration,
this.accessToken, this.refreshToken);
@ -399,18 +351,11 @@ public class RestClientRefreshTokenTokenResponseClientTests {
@Test
public void getTokenResponseWhenHeadersConverterAddedThenCalled() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration,
this.accessToken, this.refreshToken);
Converter<OAuth2RefreshTokenGrantRequest, HttpHeaders> headersConverter = mock(Converter.class);
Converter<OAuth2RefreshTokenGrantRequest, HttpHeaders> headersConverter = mock();
HttpHeaders headers = new HttpHeaders();
headers.put("custom-header-name", Collections.singletonList("custom-header-value"));
given(headersConverter.convert(grantRequest)).willReturn(headers);
@ -424,18 +369,11 @@ public class RestClientRefreshTokenTokenResponseClientTests {
@Test
public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration,
this.accessToken, this.refreshToken);
Converter<OAuth2RefreshTokenGrantRequest, HttpHeaders> headersConverter = mock(Converter.class);
Converter<OAuth2RefreshTokenGrantRequest, HttpHeaders> headersConverter = mock();
HttpHeaders headers = new HttpHeaders();
headers.put("custom-header-name", Collections.singletonList("custom-header-value"));
given(headersConverter.convert(grantRequest)).willReturn(headers);
@ -449,19 +387,11 @@ public class RestClientRefreshTokenTokenResponseClientTests {
@Test
public void getTokenResponseWhenParametersConverterSetThenCalled() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration,
this.accessToken, this.refreshToken);
Converter<OAuth2RefreshTokenGrantRequest, MultiValueMap<String, String>> parametersConverter = mock(
Converter.class);
Converter<OAuth2RefreshTokenGrantRequest, MultiValueMap<String, String>> parametersConverter = mock();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add("custom-parameter-name", "custom-parameter-value");
given(parametersConverter.convert(grantRequest)).willReturn(parameters);
@ -470,20 +400,13 @@ public class RestClientRefreshTokenTokenResponseClientTests {
verify(parametersConverter).convert(grantRequest);
RecordedRequest recordedRequest = this.server.takeRequest();
String formParameters = recordedRequest.getBody().readUtf8();
assertThat(formParameters).contains("custom-parameter-name=custom-parameter-value");
assertThat(formParameters).contains(param("custom-parameter-name", "custom-parameter-value"));
}
@Test
public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultParameters() throws Exception {
this.clientRegistration.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST);
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration,
this.accessToken, this.refreshToken);
@ -491,7 +414,6 @@ public class RestClientRefreshTokenTokenResponseClientTests {
parameters.set(OAuth2ParameterNames.GRANT_TYPE, "custom");
parameters.set(OAuth2ParameterNames.REFRESH_TOKEN, "custom-token");
parameters.set(OAuth2ParameterNames.SCOPE, "one two");
// The client_id parameter is omitted for testing purposes
this.tokenResponseClient.setParametersConverter((authorizationGrantRequest) -> parameters);
this.tokenResponseClient.getTokenResponse(grantRequest);
RecordedRequest recordedRequest = this.server.takeRequest();
@ -499,28 +421,21 @@ public class RestClientRefreshTokenTokenResponseClientTests {
// @formatter:off
assertThat(formParameters).contains(
param(OAuth2ParameterNames.GRANT_TYPE, "custom"),
param(OAuth2ParameterNames.CLIENT_ID, "client-1"),
param(OAuth2ParameterNames.REFRESH_TOKEN, "custom-token"),
param(OAuth2ParameterNames.SCOPE, "one two"));
param(OAuth2ParameterNames.SCOPE, "one two")
);
// @formatter:on
assertThat(formParameters).doesNotContain(OAuth2ParameterNames.CLIENT_ID);
}
@Test
public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
Set<String> scopes = clientRegistration.getScopes();
OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration,
this.accessToken, this.refreshToken, scopes);
Converter<OAuth2RefreshTokenGrantRequest, MultiValueMap<String, String>> parametersConverter = mock(
Converter.class);
Converter<OAuth2RefreshTokenGrantRequest, MultiValueMap<String, String>> parametersConverter = mock();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add("custom-parameter-name", "custom-parameter-value");
given(parametersConverter.convert(grantRequest)).willReturn(parameters);
@ -540,15 +455,25 @@ public class RestClientRefreshTokenTokenResponseClientTests {
}
@Test
public void getTokenResponseWhenRestClientSetThenCalled() {
public void getTokenResponseWhenParametersCustomizerSetThenCalled() throws Exception {
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration,
this.accessToken, this.refreshToken);
Consumer<MultiValueMap<String, String>> parametersCustomizer = mock();
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
DefaultOAuth2TokenRequestParametersConverter<OAuth2RefreshTokenGrantRequest> parametersConverter =
new DefaultOAuth2TokenRequestParametersConverter<>();
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
parametersConverter.setParametersCustomizer(parametersCustomizer);
this.tokenResponseClient.setParametersConverter(parametersConverter);
this.tokenResponseClient.getTokenResponse(grantRequest);
verify(parametersCustomizer).accept(any());
}
@Test
public void getTokenResponseWhenRestClientSetThenCalled() {
this.server.enqueue(MockResponses.json("access-token-response.json"));
RestClient restClient = RestClient.builder().messageConverters((messageConverters) -> {
messageConverters.add(0, new FormHttpMessageConverter());
messageConverters.add(1, new OAuth2AccessTokenResponseHttpMessageConverter());
@ -562,10 +487,6 @@ public class RestClientRefreshTokenTokenResponseClientTests {
verify(customClient).post();
}
private static MockResponse jsonResponse(String json) {
return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json);
}
private static String param(String parameterName, String parameterValue) {
return "%s=%s".formatted(parameterName, URLEncoder.encode(parameterValue, StandardCharsets.UTF_8));
}

View File

@ -22,6 +22,7 @@ import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.Collections;
import java.util.Set;
import java.util.function.Consumer;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
@ -34,6 +35,7 @@ import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.security.oauth2.client.MockResponses;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
@ -54,6 +56,7 @@ import org.springframework.web.client.RestClient;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
@ -85,14 +88,12 @@ public class RestClientTokenExchangeTokenResponseClientTests {
this.server = new MockWebServer();
this.server.start();
String tokenUri = this.server.url("/oauth2/token").toString();
// @formatter:off
this.clientRegistration = TestClientRegistrations.clientCredentials()
.clientId("client-1")
.clientSecret("secret")
.authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE)
.tokenUri(tokenUri)
.scope("read", "write");
// @formatter:on
.clientId("client-1")
.clientSecret("secret")
.authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE)
.tokenUri(tokenUri)
.scope("read", "write");
this.subjectToken = TestOAuth2AccessTokens.scopes("read", "write");
this.actorToken = null;
}
@ -158,15 +159,7 @@ public class RestClientTokenExchangeTokenResponseClientTests {
@Test
public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read write\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response-read-write.json"));
Instant expiresAtBefore = Instant.now().plusSeconds(3600);
ClientRegistration clientRegistration = this.clientRegistration.build();
Set<String> scopes = clientRegistration.getScopes();
@ -199,15 +192,7 @@ public class RestClientTokenExchangeTokenResponseClientTests {
@Test
public void getTokenResponseWhenSubjectTokenIsJwtThenSubjectTokenTypeIsJwt() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read write\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response-read-write.json"));
Instant expiresAtBefore = Instant.now().plusSeconds(3600);
this.subjectToken = TestJwts.jwt().build();
ClientRegistration clientRegistration = this.clientRegistration.build();
@ -241,15 +226,7 @@ public class RestClientTokenExchangeTokenResponseClientTests {
@Test
public void getTokenResponseWhenActorTokenIsNotNullThenActorParametersAreSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read write\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response-read-write.json"));
Instant expiresAtBefore = Instant.now().plusSeconds(3600);
this.actorToken = TestOAuth2AccessTokens.noScopes();
ClientRegistration clientRegistration = this.clientRegistration.build();
@ -285,15 +262,7 @@ public class RestClientTokenExchangeTokenResponseClientTests {
@Test
public void getTokenResponseWhenActorTokenIsJwtThenActorTokenTypeIsJwt() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read write\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response-read-write.json"));
Instant expiresAtBefore = Instant.now().plusSeconds(3600);
this.actorToken = TestJwts.jwt().build();
ClientRegistration clientRegistration = this.clientRegistration.build();
@ -329,14 +298,7 @@ public class RestClientTokenExchangeTokenResponseClientTests {
@Test
public void getTokenResponseWhenAuthenticationClientSecretBasicThenAuthorizationHeaderIsSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken,
this.actorToken);
@ -347,14 +309,7 @@ public class RestClientTokenExchangeTokenResponseClientTests {
@Test
public void getTokenResponseWhenAuthenticationClientSecretPostThenFormParametersAreSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST)
.build();
@ -363,19 +318,17 @@ public class RestClientTokenExchangeTokenResponseClientTests {
this.tokenResponseClient.getTokenResponse(grantRequest);
RecordedRequest recordedRequest = this.server.takeRequest();
String formParameters = recordedRequest.getBody().readUtf8();
assertThat(formParameters).contains("client_id=client-1", "client_secret=secret");
// @formatter:off
assertThat(formParameters).contains(
param(OAuth2ParameterNames.CLIENT_ID, "client-1"),
param(OAuth2ParameterNames.CLIENT_SECRET, "secret")
);
// @formatter:on
}
@Test
public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"not-bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("invalid-token-type-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken,
this.actorToken);
@ -390,15 +343,7 @@ public class RestClientTokenExchangeTokenResponseClientTests {
@Test
public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response-read.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken,
this.actorToken);
@ -409,14 +354,7 @@ public class RestClientTokenExchangeTokenResponseClientTests {
@Test
public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenAccessTokenHasNoScope() {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken,
this.actorToken);
@ -440,8 +378,7 @@ public class RestClientTokenExchangeTokenResponseClientTests {
@Test
public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() {
String accessTokenErrorResponse = "{\"error\": \"server_error\", \"error_description\": \"A server error occurred\"}";
this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(500));
this.server.enqueue(MockResponses.json("server-error-response.json").setResponseCode(500));
TokenExchangeGrantRequest request = new TokenExchangeGrantRequest(this.clientRegistration.build(),
this.subjectToken, this.actorToken);
// @formatter:off
@ -454,8 +391,7 @@ public class RestClientTokenExchangeTokenResponseClientTests {
@Test
public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() {
String accessTokenErrorResponse = "{\"error\": \"invalid_grant\", \"error_description\": \"Invalid grant\"}";
this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400));
this.server.enqueue(MockResponses.json("invalid-grant-response.json").setResponseCode(400));
TokenExchangeGrantRequest request = new TokenExchangeGrantRequest(this.clientRegistration.build(),
this.subjectToken, this.actorToken);
// @formatter:off
@ -496,18 +432,11 @@ public class RestClientTokenExchangeTokenResponseClientTests {
@Test
public void getTokenResponseWhenHeadersConverterAddedThenCalled() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken,
this.actorToken);
Converter<TokenExchangeGrantRequest, HttpHeaders> headersConverter = mock(Converter.class);
Converter<TokenExchangeGrantRequest, HttpHeaders> headersConverter = mock();
HttpHeaders headers = new HttpHeaders();
headers.put("custom-header-name", Collections.singletonList("custom-header-value"));
given(headersConverter.convert(grantRequest)).willReturn(headers);
@ -521,18 +450,11 @@ public class RestClientTokenExchangeTokenResponseClientTests {
@Test
public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken,
this.actorToken);
Converter<TokenExchangeGrantRequest, HttpHeaders> headersConverter = mock(Converter.class);
Converter<TokenExchangeGrantRequest, HttpHeaders> headersConverter = mock();
HttpHeaders headers = new HttpHeaders();
headers.put("custom-header-name", Collections.singletonList("custom-header-value"));
given(headersConverter.convert(grantRequest)).willReturn(headers);
@ -546,18 +468,11 @@ public class RestClientTokenExchangeTokenResponseClientTests {
@Test
public void getTokenResponseWhenParametersConverterSetThenCalled() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken,
this.actorToken);
Converter<TokenExchangeGrantRequest, MultiValueMap<String, String>> parametersConverter = mock(Converter.class);
Converter<TokenExchangeGrantRequest, MultiValueMap<String, String>> parametersConverter = mock();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add("custom-parameter-name", "custom-parameter-value");
given(parametersConverter.convert(grantRequest)).willReturn(parameters);
@ -566,20 +481,13 @@ public class RestClientTokenExchangeTokenResponseClientTests {
verify(parametersConverter).convert(grantRequest);
RecordedRequest recordedRequest = this.server.takeRequest();
String formParameters = recordedRequest.getBody().readUtf8();
assertThat(formParameters).contains("custom-parameter-name=custom-parameter-value");
assertThat(formParameters).contains(param("custom-parameter-name", "custom-parameter-value"));
}
@Test
public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultParameters() throws Exception {
this.clientRegistration.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST);
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken,
this.actorToken);
@ -587,7 +495,6 @@ public class RestClientTokenExchangeTokenResponseClientTests {
parameters.set(OAuth2ParameterNames.GRANT_TYPE, "custom");
parameters.set(OAuth2ParameterNames.SCOPE, "one two");
parameters.set(OAuth2ParameterNames.SUBJECT_TOKEN, "custom-token");
// The client_id parameter is omitted for testing purposes
this.tokenResponseClient.setParametersConverter((authorizationGrantRequest) -> parameters);
this.tokenResponseClient.getTokenResponse(grantRequest);
RecordedRequest recordedRequest = this.server.takeRequest();
@ -595,27 +502,22 @@ public class RestClientTokenExchangeTokenResponseClientTests {
// @formatter:off
assertThat(formParameters).contains(
param(OAuth2ParameterNames.GRANT_TYPE, "custom"),
param(OAuth2ParameterNames.SCOPE, "one two"),
param(OAuth2ParameterNames.SUBJECT_TOKEN, "custom-token"));
param(OAuth2ParameterNames.CLIENT_ID, "client-1"),
param(OAuth2ParameterNames.SUBJECT_TOKEN, "custom-token"),
param(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE),
param(OAuth2ParameterNames.SCOPE, "one two")
);
// @formatter:on
assertThat(formParameters).doesNotContain(OAuth2ParameterNames.CLIENT_ID);
}
@Test
public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
Set<String> scopes = clientRegistration.getScopes();
TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken,
this.actorToken);
Converter<TokenExchangeGrantRequest, MultiValueMap<String, String>> parametersConverter = mock(Converter.class);
Converter<TokenExchangeGrantRequest, MultiValueMap<String, String>> parametersConverter = mock();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add("custom-parameter-name", "custom-parameter-value");
given(parametersConverter.convert(grantRequest)).willReturn(parameters);
@ -637,16 +539,26 @@ public class RestClientTokenExchangeTokenResponseClientTests {
}
@Test
public void getTokenResponseWhenRestClientSetThenCalled() {
public void getTokenResponseWhenParametersCustomizerSetThenCalled() throws Exception {
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken,
this.actorToken);
Consumer<MultiValueMap<String, String>> parametersCustomizer = mock();
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
DefaultOAuth2TokenRequestParametersConverter<TokenExchangeGrantRequest> parametersConverter =
new DefaultOAuth2TokenRequestParametersConverter<>();
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
RestClient customClient = mock(RestClient.class);
parametersConverter.setParametersCustomizer(parametersCustomizer);
this.tokenResponseClient.setParametersConverter(parametersConverter);
this.tokenResponseClient.getTokenResponse(grantRequest);
verify(parametersCustomizer).accept(any());
}
@Test
public void getTokenResponseWhenRestClientSetThenCalled() {
this.server.enqueue(MockResponses.json("access-token-response.json"));
RestClient customClient = mock();
given(customClient.post()).willReturn(RestClient.builder().build().post());
this.tokenResponseClient.setRestClient(customClient);
ClientRegistration clientRegistration = this.clientRegistration.build();
@ -656,10 +568,6 @@ public class RestClientTokenExchangeTokenResponseClientTests {
verify(customClient).post();
}
private static MockResponse jsonResponse(String json) {
return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json);
}
private static String param(String parameterName, String parameterValue) {
return "%s=%s".formatted(parameterName, URLEncoder.encode(parameterValue, StandardCharsets.UTF_8));
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,6 +16,7 @@
package org.springframework.security.oauth2.client.endpoint;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.Collections;
@ -26,7 +27,6 @@ import java.util.function.Function;
import javax.crypto.spec.SecretKeySpec;
import com.nimbusds.jose.jwk.JWK;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
import org.junit.jupiter.api.AfterEach;
@ -36,9 +36,8 @@ import reactor.core.publisher.Mono;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ReactiveHttpInputMessage;
import org.springframework.security.oauth2.client.MockResponses;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
@ -48,6 +47,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenRespon
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
import org.springframework.security.oauth2.jose.TestJwks;
@ -93,18 +93,7 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
@Test
public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"openid profile\",\n"
+ " \"refresh_token\": \"refresh-token-1234\",\n"
+ " \"custom_parameter_1\": \"custom-value-1\",\n"
+ " \"custom_parameter_2\": \"custom-value-2\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response-openid-profile-2.json"));
Instant expiresAtBefore = Instant.now().plusSeconds(3600);
OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient
.getTokenResponse(authorizationCodeGrantRequest())
@ -125,14 +114,7 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
@Test
public void getTokenResponseWhenAuthenticationClientSecretJwtThenFormParametersAreSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
// @formatter:off
ClientRegistration clientRegistration = this.clientRegistration
@ -158,14 +140,7 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
@Test
public void getTokenResponseWhenAuthenticationPrivateKeyJwtThenFormParametersAreSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
// @formatter:off
ClientRegistration clientRegistration = this.clientRegistration
@ -194,9 +169,7 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
@Test
public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() {
String accessTokenErrorResponse = "{\n" + " \"error\": \"unauthorized_client\"\n" + "}\n";
this.server
.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(HttpStatus.INTERNAL_SERVER_ERROR.value()));
this.server.enqueue(MockResponses.json("unauthorized-client-response.json").setResponseCode(500));
assertThatExceptionOfType(OAuth2AuthorizationException.class)
.isThrownBy(() -> this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block())
.satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo("unauthorized_client"))
@ -206,9 +179,7 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
// gh-5594
@Test
public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() {
String accessTokenErrorResponse = "{}";
this.server
.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(HttpStatus.INTERNAL_SERVER_ERROR.value()));
this.server.enqueue(MockResponses.json("server-error-response.json").setResponseCode(500));
assertThatExceptionOfType(OAuth2AuthorizationException.class)
.isThrownBy(() -> this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block())
.withMessageContaining("server_error");
@ -216,14 +187,7 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
@Test
public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ "\"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"not-bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("invalid-token-type-response.json"));
assertThatExceptionOfType(OAuth2AuthorizationException.class)
.isThrownBy(() -> this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block())
.withMessageContaining("invalid_token_response");
@ -231,15 +195,7 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
@Test
public void getTokenResponseWhenSuccessResponseIncludesScopeThenReturnAccessTokenResponseUsingResponseScope() {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ "\"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"openid profile\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response-openid-profile.json"));
this.clientRegistration.scope("openid", "profile", "email", "address");
OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient
.getTokenResponse(authorizationCodeGrantRequest())
@ -249,14 +205,7 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
@Test
public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenReturnAccessTokenResponseWithNoScopes() {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
this.clientRegistration.scope("openid", "profile", "email", "address");
OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient
.getTokenResponse(authorizationCodeGrantRequest())
@ -285,10 +234,6 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
return new OAuth2AuthorizationCodeGrantRequest(registration, authorizationExchange);
}
private MockResponse jsonResponse(String json) {
return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json);
}
@Test
public void setWebClientNullThenIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.setWebClient(null));
@ -296,18 +241,10 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
@Test
public void setCustomWebClientThenCustomWebClientIsUsed() {
WebClient customClient = mock(WebClient.class);
WebClient customClient = mock();
given(customClient.post()).willReturn(WebClient.builder().build().post());
this.tokenResponseClient.setWebClient(customClient);
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"openid profile\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
this.clientRegistration.scope("openid", "profile", "email", "address");
OAuth2AccessTokenResponse response = this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest())
.block();
@ -317,14 +254,7 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
@Test
public void getTokenResponseWhenOAuth2AuthorizationRequestContainsPkceParametersThenTokenRequestBodyShouldContainCodeVerifier()
throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
this.tokenResponseClient.getTokenResponse(pkceAuthorizationCodeGrantRequest()).block();
String body = this.server.takeRequest().getBody().readUtf8();
assertThat(body).isEqualTo(
@ -379,20 +309,12 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
@Test
public void convertWhenHeadersConverterAddedThenCalled() throws Exception {
OAuth2AuthorizationCodeGrantRequest request = authorizationCodeGrantRequest();
Converter<OAuth2AuthorizationCodeGrantRequest, HttpHeaders> addedHeadersConverter = mock(Converter.class);
Converter<OAuth2AuthorizationCodeGrantRequest, HttpHeaders> addedHeadersConverter = mock();
HttpHeaders headers = new HttpHeaders();
headers.put("custom-header-name", Collections.singletonList("custom-header-value"));
given(addedHeadersConverter.convert(request)).willReturn(headers);
this.tokenResponseClient.addHeadersConverter(addedHeadersConverter);
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"openid profile\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
this.tokenResponseClient.getTokenResponse(request).block();
verify(addedHeadersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
@ -406,20 +328,12 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
public void convertWhenHeadersConverterSetThenCalled() throws Exception {
OAuth2AuthorizationCodeGrantRequest request = authorizationCodeGrantRequest();
ClientRegistration clientRegistration = request.getClientRegistration();
Converter<OAuth2AuthorizationCodeGrantRequest, HttpHeaders> headersConverter = mock(Converter.class);
Converter<OAuth2AuthorizationCodeGrantRequest, HttpHeaders> headersConverter = mock();
HttpHeaders headers = new HttpHeaders();
headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret());
given(headersConverter.convert(request)).willReturn(headers);
this.tokenResponseClient.setHeadersConverter(headersConverter);
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"openid profile\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
this.tokenResponseClient.getTokenResponse(request).block();
verify(headersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
@ -440,23 +354,14 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
}
@Test
public void convertWhenParametersConverterAddedThenCalled() throws Exception {
public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exception {
OAuth2AuthorizationCodeGrantRequest request = authorizationCodeGrantRequest();
Converter<OAuth2AuthorizationCodeGrantRequest, MultiValueMap<String, String>> addedParametersConverter = mock(
Converter.class);
Converter<OAuth2AuthorizationCodeGrantRequest, MultiValueMap<String, String>> addedParametersConverter = mock();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add("custom-parameter-name", "custom-parameter-value");
given(addedParametersConverter.convert(request)).willReturn(parameters);
this.tokenResponseClient.addParametersConverter(addedParametersConverter);
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
+ "}";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
this.tokenResponseClient.getTokenResponse(request).block();
verify(addedParametersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
@ -465,44 +370,55 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
}
@Test
public void convertWhenParametersConverterSetThenCalled() throws Exception {
public void getTokenResponseWhenParametersConverterSetThenCalled() throws Exception {
OAuth2AuthorizationCodeGrantRequest request = authorizationCodeGrantRequest();
Converter<OAuth2AuthorizationCodeGrantRequest, MultiValueMap<String, String>> parametersConverter = mock(
Converter.class);
Converter<OAuth2AuthorizationCodeGrantRequest, MultiValueMap<String, String>> parametersConverter = mock();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add("custom-parameter-name", "custom-parameter-value");
given(parametersConverter.convert(request)).willReturn(parameters);
this.tokenResponseClient.setParametersConverter(parametersConverter);
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
+ "}";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
this.tokenResponseClient.getTokenResponse(request).block();
verify(parametersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
assertThat(actualRequest.getBody().readUtf8()).contains("custom-parameter-name=custom-parameter-value");
}
@Test
public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultParameters() throws Exception {
this.clientRegistration.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST);
OAuth2AuthorizationCodeGrantRequest request = authorizationCodeGrantRequest();
this.server.enqueue(MockResponses.json("access-token-response.json"));
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.set(OAuth2ParameterNames.GRANT_TYPE, "custom");
parameters.set(OAuth2ParameterNames.CODE, "custom-code");
parameters.set(OAuth2ParameterNames.REDIRECT_URI, "custom-uri");
this.tokenResponseClient.setParametersConverter((grantRequest) -> parameters);
this.tokenResponseClient.getTokenResponse(request).block();
String formParameters = this.server.takeRequest().getBody().readUtf8();
// @formatter:off
assertThat(formParameters).contains(
param(OAuth2ParameterNames.GRANT_TYPE, "custom"),
param(OAuth2ParameterNames.CLIENT_ID, "client-id"),
param(OAuth2ParameterNames.CODE, "custom-code"),
param(OAuth2ParameterNames.REDIRECT_URI, "custom-uri")
);
// @formatter:on
}
// gh-10260
@Test
public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse() {
String accessTokenSuccessResponse = "{}";
WebClientReactiveAuthorizationCodeTokenResponseClient customClient = new WebClientReactiveAuthorizationCodeTokenResponseClient();
BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> extractor = mock(BodyExtractor.class);
BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> extractor = mock();
OAuth2AccessTokenResponse response = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(extractor.extract(any(), any())).willReturn(Mono.just(response));
customClient.setBodyExtractor(extractor);
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
OAuth2AccessTokenResponse accessTokenResponse = customClient.getTokenResponse(authorizationCodeGrantRequest())
.block();
assertThat(accessTokenResponse.getAccessToken()).isNotNull();
@ -533,4 +449,8 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
.isThrownBy(() -> this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest).block());
}
private static String param(String parameterName, String parameterValue) {
return "%s=%s".formatted(parameterName, URLEncoder.encode(parameterValue, StandardCharsets.UTF_8));
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -37,11 +37,13 @@ import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ReactiveHttpInputMessage;
import org.springframework.security.oauth2.client.MockResponses;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
import org.springframework.security.oauth2.jose.TestJwks;
import org.springframework.security.oauth2.jose.TestKeys;
@ -88,15 +90,7 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests {
@Test
public void getTokenResponseWhenHeaderThenSuccess() throws Exception {
// @formatter:off
enqueueJson("{\n"
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\",\n"
+ " \"scope\":\"create\"\n"
+ "}");
// @formatter:on
this.server.enqueue(MockResponses.json("access-token-response-create.json"));
OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(
this.clientRegistration.build());
OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block();
@ -112,15 +106,7 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests {
// gh-9610
@Test
public void getTokenResponseWhenSpecialCharactersThenSuccessWithEncodedClientCredentials() throws Exception {
// @formatter:off
enqueueJson("{\n"
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\",\n"
+ " \"scope\":\"create\"\n"
+ "}");
// @formatter:on
this.server.enqueue(MockResponses.json("access-token-response-create.json"));
String clientCredentialWithAnsiKeyboardSpecialCharacters = "~!@#$%^&*()_+{}|:\"<>?`-=[]\\;',./ ";
OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(
this.clientRegistration.clientId(clientCredentialWithAnsiKeyboardSpecialCharacters)
@ -145,15 +131,7 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests {
ClientRegistration registration = this.clientRegistration
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST)
.build();
// @formatter:off
enqueueJson("{\n"
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\",\n"
+ " \"scope\":\"create\"\n"
+ "}");
// @formatter:on
this.server.enqueue(MockResponses.json("access-token-response-create.json"));
OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(registration);
OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block();
RecordedRequest actualRequest = this.server.takeRequest();
@ -167,13 +145,7 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests {
@Test
public void getTokenResponseWhenAuthenticationClientSecretJwtThenFormParametersAreSent() throws Exception {
// @formatter:off
enqueueJson("{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}");
// @formatter:on
this.server.enqueue(MockResponses.json("access-token-response.json"));
// @formatter:off
ClientRegistration clientRegistration = this.clientRegistration
@ -200,13 +172,7 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests {
@Test
public void getTokenResponseWhenAuthenticationPrivateKeyJwtThenFormParametersAreSent() throws Exception {
// @formatter:off
enqueueJson("{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}");
// @formatter:on
this.server.enqueue(MockResponses.json("access-token-response.json"));
// @formatter:off
ClientRegistration clientRegistration = this.clientRegistration
@ -237,14 +203,7 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests {
@Test
public void getTokenResponseWhenNoScopeThenReturnAccessTokenResponseWithNoScopes() {
ClientRegistration registration = this.clientRegistration.build();
// @formatter:off
enqueueJson("{\n"
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
+ "}");
// @formatter:on
this.server.enqueue(MockResponses.json("access-token-response.json"));
OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(registration);
OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block();
assertThat(response.getAccessToken().getScopes()).isEmpty();
@ -257,18 +216,11 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests {
@Test
public void setWebClientCustomThenCustomClientIsUsed() {
WebClient customClient = mock(WebClient.class);
WebClient customClient = mock();
given(customClient.post()).willReturn(WebClient.builder().build().post());
this.client.setWebClient(customClient);
ClientRegistration registration = this.clientRegistration.build();
// @formatter:off
enqueueJson("{\n"
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
+ "}");
// @formatter:on
this.server.enqueue(MockResponses.json("access-token-response.json"));
OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(registration);
OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block();
verify(customClient, atLeastOnce()).post();
@ -295,12 +247,6 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests {
this.server.enqueue(response);
}
private void enqueueJson(String body) {
MockResponse response = new MockResponse().setBody(body)
.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE);
this.server.enqueue(response);
}
// gh-10130
@Test
public void setHeadersConverterWhenNullThenThrowIllegalArgumentException() {
@ -320,19 +266,12 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests {
public void convertWhenHeadersConverterAddedThenCalled() throws Exception {
OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(
this.clientRegistration.build());
Converter<OAuth2ClientCredentialsGrantRequest, HttpHeaders> addedHeadersConverter = mock(Converter.class);
Converter<OAuth2ClientCredentialsGrantRequest, HttpHeaders> addedHeadersConverter = mock();
HttpHeaders headers = new HttpHeaders();
headers.put("custom-header-name", Collections.singletonList("custom-header-value"));
given(addedHeadersConverter.convert(request)).willReturn(headers);
this.client.addHeadersConverter(addedHeadersConverter);
// @formatter:off
enqueueJson("{\n"
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
+ "}");
// @formatter:on
this.server.enqueue(MockResponses.json("access-token-response.json"));
this.client.getTokenResponse(request).block();
verify(addedHeadersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
@ -347,19 +286,12 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests {
OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(
this.clientRegistration.build());
ClientRegistration clientRegistration = request.getClientRegistration();
Converter<OAuth2ClientCredentialsGrantRequest, HttpHeaders> headersConverter = mock(Converter.class);
Converter<OAuth2ClientCredentialsGrantRequest, HttpHeaders> headersConverter = mock();
HttpHeaders headers = new HttpHeaders();
headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret());
given(headersConverter.convert(request)).willReturn(headers);
this.client.setHeadersConverter(headersConverter);
// @formatter:off
enqueueJson("{\n"
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
+ "}");
// @formatter:on
this.server.enqueue(MockResponses.json("access-token-response.json"));
this.client.getTokenResponse(request).block();
verify(headersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
@ -380,23 +312,15 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests {
}
@Test
public void convertWhenParametersConverterAddedThenCalled() throws Exception {
public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exception {
OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(
this.clientRegistration.build());
Converter<OAuth2ClientCredentialsGrantRequest, MultiValueMap<String, String>> addedParametersConverter = mock(
Converter.class);
Converter<OAuth2ClientCredentialsGrantRequest, MultiValueMap<String, String>> addedParametersConverter = mock();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add("custom-parameter-name", "custom-parameter-value");
given(addedParametersConverter.convert(request)).willReturn(parameters);
this.client.addParametersConverter(addedParametersConverter);
// @formatter:off
enqueueJson("{\n"
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
+ "}");
// @formatter:on
this.server.enqueue(MockResponses.json("access-token-response.json"));
this.client.getTokenResponse(request).block();
verify(addedParametersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
@ -405,38 +329,51 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests {
}
@Test
public void convertWhenParametersConverterSetThenCalled() throws Exception {
public void getTokenResponseWhenParametersConverterSetThenCalled() throws Exception {
OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(
this.clientRegistration.build());
Converter<OAuth2ClientCredentialsGrantRequest, MultiValueMap<String, String>> parametersConverter = mock(
Converter.class);
Converter<OAuth2ClientCredentialsGrantRequest, MultiValueMap<String, String>> parametersConverter = mock();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add("custom-parameter-name", "custom-parameter-value");
given(parametersConverter.convert(request)).willReturn(parameters);
this.client.setParametersConverter(parametersConverter);
// @formatter:off
enqueueJson("{\n"
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
+ "}");
// @formatter:on
this.server.enqueue(MockResponses.json("access-token-response.json"));
this.client.getTokenResponse(request).block();
verify(parametersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
assertThat(actualRequest.getBody().readUtf8()).contains("custom-parameter-name=custom-parameter-value");
}
@Test
public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultParameters() throws Exception {
this.clientRegistration.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST);
OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(
this.clientRegistration.build());
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.set(OAuth2ParameterNames.GRANT_TYPE, "custom");
parameters.set(OAuth2ParameterNames.SCOPE, "one two");
this.client.setParametersConverter((grantRequest) -> parameters);
this.server.enqueue(MockResponses.json("access-token-response.json"));
this.client.getTokenResponse(request).block();
String formParameters = this.server.takeRequest().getBody().readUtf8();
// @formatter:off
assertThat(formParameters).contains(
param(OAuth2ParameterNames.GRANT_TYPE, "custom"),
param(OAuth2ParameterNames.CLIENT_ID, "client-id"),
param(OAuth2ParameterNames.SCOPE, "one two")
);
// @formatter:on
}
// gh-10260
@Test
public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse() {
enqueueJson("{}");
this.server.enqueue(MockResponses.json("access-token-response.json"));
WebClientReactiveClientCredentialsTokenResponseClient customClient = new WebClientReactiveClientCredentialsTokenResponseClient();
BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> extractor = mock(BodyExtractor.class);
BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> extractor = mock();
OAuth2AccessTokenResponse response = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(extractor.extract(any(), any())).willReturn(Mono.just(response));
@ -474,4 +411,8 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests {
.isThrownBy(() -> this.client.getTokenResponse(clientCredentialsGrantRequest).block());
}
private static String param(String parameterName, String parameterValue) {
return "%s=%s".formatted(parameterName, URLEncoder.encode(parameterValue, StandardCharsets.UTF_8));
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,6 +16,8 @@
package org.springframework.security.oauth2.client.endpoint;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import okhttp3.mockwebserver.MockResponse;
@ -30,6 +32,7 @@ import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ReactiveHttpInputMessage;
import org.springframework.security.oauth2.client.MockResponses;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
@ -37,6 +40,7 @@ import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.TestJwts;
@ -82,12 +86,10 @@ public class WebClientReactiveJwtBearerTokenResponseClientTests {
this.server = new MockWebServer();
this.server.start();
String tokenUri = this.server.url("/oauth2/token").toString();
// @formatter:off
this.clientRegistration = TestClientRegistrations.clientCredentials()
.authorizationGrantType(AuthorizationGrantType.JWT_BEARER)
.tokenUri(tokenUri)
.scope("read", "write");
// @formatter:on
.authorizationGrantType(AuthorizationGrantType.JWT_BEARER)
.tokenUri(tokenUri)
.scope("read", "write");
this.jwtAssertion = TestJwts.jwt().build();
}
@ -150,13 +152,8 @@ public class WebClientReactiveJwtBearerTokenResponseClientTests {
@Test
public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() {
// @formatter:off
String accessTokenResponse = "{\n"
+ " \"error\": \"invalid_grant\"\n"
+ "}\n";
// @formatter:on
ClientRegistration registration = this.clientRegistration.build();
enqueueJson(accessTokenResponse);
this.server.enqueue(MockResponses.json("invalid-grant-response.json").setResponseCode(400));
JwtBearerGrantRequest request = new JwtBearerGrantRequest(registration, this.jwtAssertion);
assertThatExceptionOfType(OAuth2AuthorizationException.class)
.isThrownBy(() -> this.client.getTokenResponse(request).block())
@ -166,15 +163,8 @@ public class WebClientReactiveJwtBearerTokenResponseClientTests {
@Test
public void getTokenResponseWhenResponseIsNotBearerTokenTypeThenThrowOAuth2AuthorizationException() {
// @formatter:off
String accessTokenResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"not-bearer\",\n"
+ " \"expires_in\": 3600\n"
+ "}\n";
// @formatter:on
ClientRegistration registration = this.clientRegistration.build();
enqueueJson(accessTokenResponse);
this.server.enqueue(MockResponses.json("invalid-token-type-response.json"));
JwtBearerGrantRequest request = new JwtBearerGrantRequest(registration, this.jwtAssertion);
assertThatExceptionOfType(OAuth2AuthorizationException.class)
.isThrownBy(() -> this.client.getTokenResponse(request).block())
@ -185,10 +175,10 @@ public class WebClientReactiveJwtBearerTokenResponseClientTests {
@Test
public void getTokenResponseWhenWebClientSetThenCalled() {
WebClient customClient = mock(WebClient.class);
WebClient customClient = mock();
given(customClient.post()).willReturn(WebClient.builder().build().post());
this.client.setWebClient(customClient);
enqueueJson(DEFAULT_ACCESS_TOKEN_RESPONSE);
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration registration = this.clientRegistration.build();
JwtBearerGrantRequest request = new JwtBearerGrantRequest(registration, this.jwtAssertion);
this.client.getTokenResponse(request).block();
@ -199,12 +189,12 @@ public class WebClientReactiveJwtBearerTokenResponseClientTests {
public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception {
ClientRegistration clientRegistration = this.clientRegistration.build();
JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion);
Converter<JwtBearerGrantRequest, HttpHeaders> headersConverter = mock(Converter.class);
Converter<JwtBearerGrantRequest, HttpHeaders> headersConverter = mock();
HttpHeaders headers = new HttpHeaders();
headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret());
given(headersConverter.convert(request)).willReturn(headers);
this.client.setHeadersConverter(headersConverter);
enqueueJson(DEFAULT_ACCESS_TOKEN_RESPONSE);
this.server.enqueue(MockResponses.json("access-token-response.json"));
this.client.getTokenResponse(request).block();
verify(headersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
@ -216,12 +206,12 @@ public class WebClientReactiveJwtBearerTokenResponseClientTests {
public void getTokenResponseWhenHeadersConverterAddedThenCalled() throws Exception {
ClientRegistration clientRegistration = this.clientRegistration.build();
JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion);
Converter<JwtBearerGrantRequest, HttpHeaders> addedHeadersConverter = mock(Converter.class);
Converter<JwtBearerGrantRequest, HttpHeaders> addedHeadersConverter = mock();
HttpHeaders headers = new HttpHeaders();
headers.put("custom-header-name", Collections.singletonList("custom-header-value"));
given(addedHeadersConverter.convert(request)).willReturn(headers);
this.client.addHeadersConverter(addedHeadersConverter);
enqueueJson(DEFAULT_ACCESS_TOKEN_RESPONSE);
this.server.enqueue(MockResponses.json("access-token-response.json"));
this.client.getTokenResponse(request).block();
verify(addedHeadersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
@ -243,16 +233,15 @@ public class WebClientReactiveJwtBearerTokenResponseClientTests {
}
@Test
public void convertWhenParametersConverterAddedThenCalled() throws Exception {
public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exception {
ClientRegistration clientRegistration = this.clientRegistration.build();
JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion);
Converter<JwtBearerGrantRequest, MultiValueMap<String, String>> addedParametersConverter = mock(
Converter.class);
Converter<JwtBearerGrantRequest, MultiValueMap<String, String>> addedParametersConverter = mock();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add("custom-parameter-name", "custom-parameter-value");
given(addedParametersConverter.convert(request)).willReturn(parameters);
this.client.addParametersConverter(addedParametersConverter);
enqueueJson(DEFAULT_ACCESS_TOKEN_RESPONSE);
this.server.enqueue(MockResponses.json("access-token-response.json"));
this.client.getTokenResponse(request).block();
verify(addedParametersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
@ -262,48 +251,62 @@ public class WebClientReactiveJwtBearerTokenResponseClientTests {
}
@Test
public void convertWhenParametersConverterSetThenCalled() throws Exception {
public void getTokenResponseWhenParametersConverterSetThenCalled() throws Exception {
ClientRegistration clientRegistration = this.clientRegistration.build();
JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion);
Converter<JwtBearerGrantRequest, MultiValueMap<String, String>> parametersConverter = mock(Converter.class);
Converter<JwtBearerGrantRequest, MultiValueMap<String, String>> parametersConverter = mock();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add("custom-parameter-name", "custom-parameter-value");
given(parametersConverter.convert(request)).willReturn(parameters);
this.client.setParametersConverter(parametersConverter);
enqueueJson(DEFAULT_ACCESS_TOKEN_RESPONSE);
this.server.enqueue(MockResponses.json("access-token-response.json"));
this.client.getTokenResponse(request).block();
verify(parametersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
assertThat(actualRequest.getBody().readUtf8()).contains("custom-parameter-name=custom-parameter-value");
}
@Test
public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultParameters() throws Exception {
this.clientRegistration.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST);
ClientRegistration clientRegistration = this.clientRegistration.build();
JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion);
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.set(OAuth2ParameterNames.GRANT_TYPE, "custom");
parameters.set(OAuth2ParameterNames.ASSERTION, "custom-assertion");
parameters.set(OAuth2ParameterNames.SCOPE, "one two");
this.client.setParametersConverter((grantRequest) -> parameters);
this.server.enqueue(MockResponses.json("access-token-response.json"));
this.client.getTokenResponse(request).block();
String formParameters = this.server.takeRequest().getBody().readUtf8();
// @formatter:off
assertThat(formParameters).contains(
param(OAuth2ParameterNames.GRANT_TYPE, "custom"),
param(OAuth2ParameterNames.CLIENT_ID, "client-id"),
param(OAuth2ParameterNames.SCOPE, "one two"),
param(OAuth2ParameterNames.ASSERTION, "custom-assertion")
);
// @formatter:on
}
@Test
public void getTokenResponseWhenBodyExtractorSetThenCalled() {
BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> bodyExtractor = mock(
BodyExtractor.class);
BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> bodyExtractor = mock();
OAuth2AccessTokenResponse response = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(bodyExtractor.extract(any(), any())).willReturn(Mono.just(response));
ClientRegistration clientRegistration = this.clientRegistration.build();
JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion);
this.client.setBodyExtractor(bodyExtractor);
enqueueJson(DEFAULT_ACCESS_TOKEN_RESPONSE);
this.server.enqueue(MockResponses.json("access-token-response.json"));
this.client.getTokenResponse(request).block();
verify(bodyExtractor).extract(any(), any());
}
@Test
public void getTokenResponseWhenClientSecretBasicThenSuccess() throws Exception {
// @formatter:off
String accessTokenResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": 3600,\n"
+ " \"scope\": \"read write\""
+ "}\n";
// @formatter:on
ClientRegistration clientRegistration = this.clientRegistration.build();
JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion);
enqueueJson(accessTokenResponse);
this.server.enqueue(MockResponses.json("access-token-response-read-write.json"));
OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block();
assertThat(response).isNotNull();
assertThat(response.getAccessToken().getScopes()).containsExactly("read", "write");
@ -317,18 +320,12 @@ public class WebClientReactiveJwtBearerTokenResponseClientTests {
@Test
public void getTokenResponseWhenClientSecretPostThenSuccess() throws Exception {
// @formatter:off
String accessTokenResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": 3600,\n"
+ " \"scope\": \"read write\""
+ "}\n";
ClientRegistration clientRegistration = this.clientRegistration
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST)
.build();
// @formatter:on
JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion);
enqueueJson(accessTokenResponse);
this.server.enqueue(MockResponses.json("access-token-response-read-write.json"));
OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block();
assertThat(response).isNotNull();
assertThat(response.getAccessToken().getScopes()).containsExactly("read", "write");
@ -340,17 +337,9 @@ public class WebClientReactiveJwtBearerTokenResponseClientTests {
@Test
public void getTokenResponseWhenResponseIncludesScopeThenAccessTokenHasResponseScope() throws Exception {
// @formatter:off
String accessTokenResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": 3600,\n"
+ " \"scope\": \"read\"\n"
+ "}\n";
// @formatter:on
ClientRegistration clientRegistration = this.clientRegistration.build();
JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion);
enqueueJson(accessTokenResponse);
this.server.enqueue(MockResponses.json("access-token-response-read.json"));
OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block();
assertThat(response).isNotNull();
assertThat(response.getAccessToken().getScopes()).containsExactly("read");
@ -361,7 +350,7 @@ public class WebClientReactiveJwtBearerTokenResponseClientTests {
throws Exception {
ClientRegistration clientRegistration = this.clientRegistration.build();
JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion);
enqueueJson(DEFAULT_ACCESS_TOKEN_RESPONSE);
this.server.enqueue(MockResponses.json("access-token-response.json"));
OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block();
assertThat(response).isNotNull();
assertThat(response.getAccessToken().getScopes()).isEmpty();
@ -389,12 +378,6 @@ public class WebClientReactiveJwtBearerTokenResponseClientTests {
.isThrownBy(() -> this.client.getTokenResponse(jwtBearerGrantRequest).block());
}
private void enqueueJson(String body) {
MockResponse response = new MockResponse().setBody(body)
.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE);
this.server.enqueue(response);
}
private void enqueueUnexpectedResponse() {
// @formatter:off
MockResponse response = new MockResponse()
@ -414,4 +397,8 @@ public class WebClientReactiveJwtBearerTokenResponseClientTests {
this.server.enqueue(response);
}
private static String param(String parameterName, String parameterValue) {
return "%s=%s".formatted(parameterName, URLEncoder.encode(parameterValue, StandardCharsets.UTF_8));
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,6 +16,7 @@
package org.springframework.security.oauth2.client.endpoint;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.Collections;
@ -37,12 +38,14 @@ import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ReactiveHttpInputMessage;
import org.springframework.security.oauth2.client.MockResponses;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
import org.springframework.security.oauth2.jose.TestJwks;
import org.springframework.security.oauth2.jose.TestKeys;
@ -101,14 +104,7 @@ public class WebClientReactivePasswordTokenResponseClientTests {
@Test
public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenReturnAccessTokenResponseWithNoScope()
throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
Instant expiresAtBefore = Instant.now().plusSeconds(3600);
ClientRegistration clientRegistration = this.clientRegistrationBuilder.build();
OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration,
@ -135,15 +131,7 @@ public class WebClientReactivePasswordTokenResponseClientTests {
@Test
public void getTokenResponseWhenSuccessResponseIncludesScopeThenReturnAccessTokenResponse() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read write\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response-read-write.json"));
Instant expiresAtBefore = Instant.now().plusSeconds(3600);
ClientRegistration clientRegistration = this.clientRegistrationBuilder.build();
OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration,
@ -171,14 +159,7 @@ public class WebClientReactivePasswordTokenResponseClientTests {
@Test
public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistrationBuilder
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST)
.build();
@ -194,14 +175,7 @@ public class WebClientReactivePasswordTokenResponseClientTests {
@Test
public void getTokenResponseWhenAuthenticationClientSecretJwtThenFormParametersAreSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
// @formatter:off
ClientRegistration clientRegistration = this.clientRegistrationBuilder
@ -229,14 +203,7 @@ public class WebClientReactivePasswordTokenResponseClientTests {
@Test
public void getTokenResponseWhenAuthenticationPrivateKeyJwtThenFormParametersAreSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
// @formatter:off
ClientRegistration clientRegistration = this.clientRegistrationBuilder
@ -267,14 +234,7 @@ public class WebClientReactivePasswordTokenResponseClientTests {
@Test
public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"not-bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("invalid-token-type-response.json"));
OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(
this.clientRegistrationBuilder.build(), this.username, this.password);
assertThatExceptionOfType(OAuth2AuthorizationException.class)
@ -287,15 +247,7 @@ public class WebClientReactivePasswordTokenResponseClientTests {
@Test
public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response-read.json"));
OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(
this.clientRegistrationBuilder.build(), this.username, this.password);
OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(passwordGrantRequest)
@ -308,12 +260,7 @@ public class WebClientReactivePasswordTokenResponseClientTests {
@Test
public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() {
// @formatter:off
String accessTokenErrorResponse = "{\n"
+ " \"error\": \"unauthorized_client\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400));
this.server.enqueue(MockResponses.json("unauthorized-client-response.json").setResponseCode(400));
OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(
this.clientRegistrationBuilder.build(), this.username, this.password);
assertThatExceptionOfType(OAuth2AuthorizationException.class)
@ -334,14 +281,6 @@ public class WebClientReactivePasswordTokenResponseClientTests {
.withMessageContaining("Empty OAuth 2.0 Access Token Response");
}
private MockResponse jsonResponse(String json) {
// @formatter:off
return new MockResponse()
.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.setBody(json);
// @formatter:on
}
// gh-10130
@Test
public void setHeadersConverterWhenNullThenThrowIllegalArgumentException() {
@ -361,20 +300,12 @@ public class WebClientReactivePasswordTokenResponseClientTests {
public void convertWhenHeadersConverterAddedThenCalled() throws Exception {
OAuth2PasswordGrantRequest request = new OAuth2PasswordGrantRequest(this.clientRegistrationBuilder.build(),
this.username, this.password);
Converter<OAuth2PasswordGrantRequest, HttpHeaders> addedHeadersConverter = mock(Converter.class);
Converter<OAuth2PasswordGrantRequest, HttpHeaders> addedHeadersConverter = mock();
HttpHeaders headers = new HttpHeaders();
headers.put("custom-header-name", Collections.singletonList("custom-header-value"));
given(addedHeadersConverter.convert(request)).willReturn(headers);
this.tokenResponseClient.addHeadersConverter(addedHeadersConverter);
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
this.tokenResponseClient.getTokenResponse(request).block();
verify(addedHeadersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
@ -389,20 +320,12 @@ public class WebClientReactivePasswordTokenResponseClientTests {
OAuth2PasswordGrantRequest request = new OAuth2PasswordGrantRequest(this.clientRegistrationBuilder.build(),
this.username, this.password);
ClientRegistration clientRegistration = request.getClientRegistration();
Converter<OAuth2PasswordGrantRequest, HttpHeaders> headersConverter = mock(Converter.class);
Converter<OAuth2PasswordGrantRequest, HttpHeaders> headersConverter = mock();
HttpHeaders headers = new HttpHeaders();
headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret());
given(headersConverter.convert(request)).willReturn(headers);
this.tokenResponseClient.setHeadersConverter(headersConverter);
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
this.tokenResponseClient.getTokenResponse(request).block();
verify(headersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
@ -423,65 +346,75 @@ public class WebClientReactivePasswordTokenResponseClientTests {
}
@Test
public void convertWhenParametersConverterAddedThenCalled() throws Exception {
public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exception {
OAuth2PasswordGrantRequest request = new OAuth2PasswordGrantRequest(this.clientRegistrationBuilder.build(),
this.username, this.password);
Converter<OAuth2PasswordGrantRequest, MultiValueMap<String, String>> addedParametersConverter = mock(
Converter.class);
Converter<OAuth2PasswordGrantRequest, MultiValueMap<String, String>> addedParametersConverter = mock();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add("custom-parameter-name", "custom-parameter-value");
given(addedParametersConverter.convert(request)).willReturn(parameters);
this.tokenResponseClient.addParametersConverter(addedParametersConverter);
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
+ "}";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
this.tokenResponseClient.getTokenResponse(request).block();
verify(addedParametersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
assertThat(actualRequest.getBody().readUtf8()).contains("grant_type=password",
"custom-parameter-name=custom-parameter-value");
String formParameters = actualRequest.getBody().readUtf8();
// @formatter:off
assertThat(formParameters).contains(
param(OAuth2ParameterNames.GRANT_TYPE, "password"),
param("custom-parameter-name", "custom-parameter-value")
);
// @formatter:on
}
@Test
public void convertWhenParametersConverterSetThenCalled() throws Exception {
public void getTokenResponseWhenParametersConverterSetThenCalled() throws Exception {
OAuth2PasswordGrantRequest request = new OAuth2PasswordGrantRequest(this.clientRegistrationBuilder.build(),
this.username, this.password);
Converter<OAuth2PasswordGrantRequest, MultiValueMap<String, String>> parametersConverter = mock(
Converter.class);
Converter<OAuth2PasswordGrantRequest, MultiValueMap<String, String>> parametersConverter = mock();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add("custom-parameter-name", "custom-parameter-value");
given(parametersConverter.convert(request)).willReturn(parameters);
this.tokenResponseClient.setParametersConverter(parametersConverter);
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
+ "}";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
this.tokenResponseClient.getTokenResponse(request).block();
verify(parametersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
assertThat(actualRequest.getBody().readUtf8()).contains("custom-parameter-name=custom-parameter-value");
}
@Test
public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultParameters() throws Exception {
this.clientRegistrationBuilder.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST);
OAuth2PasswordGrantRequest request = new OAuth2PasswordGrantRequest(this.clientRegistrationBuilder.build(),
this.username, this.password);
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.set(OAuth2ParameterNames.GRANT_TYPE, "custom");
parameters.set(OAuth2ParameterNames.USERNAME, "user");
parameters.set(OAuth2ParameterNames.PASSWORD, "password");
parameters.set(OAuth2ParameterNames.SCOPE, "one two");
this.tokenResponseClient.setParametersConverter((grantRequest) -> parameters);
this.server.enqueue(MockResponses.json("access-token-response.json"));
this.tokenResponseClient.getTokenResponse(request).block();
String formParameters = this.server.takeRequest().getBody().readUtf8();
// @formatter:off
assertThat(formParameters).contains(
param(OAuth2ParameterNames.GRANT_TYPE, "custom"),
param(OAuth2ParameterNames.CLIENT_ID, "client-id"),
param(OAuth2ParameterNames.SCOPE, "one two"),
param(OAuth2ParameterNames.USERNAME, "user"),
param(OAuth2ParameterNames.PASSWORD, "password")
);
// @formatter:on
}
// gh-10260
@Test
public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse() {
String accessTokenSuccessResponse = "{}";
WebClientReactivePasswordTokenResponseClient customClient = new WebClientReactivePasswordTokenResponseClient();
BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> extractor = mock(BodyExtractor.class);
BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> extractor = mock();
OAuth2AccessTokenResponse response = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(extractor.extract(any(), any())).willReturn(Mono.just(response));
@ -491,11 +424,15 @@ public class WebClientReactivePasswordTokenResponseClientTests {
OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration,
this.username, this.password);
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
OAuth2AccessTokenResponse accessTokenResponse = customClient.getTokenResponse(passwordGrantRequest).block();
assertThat(accessTokenResponse.getAccessToken()).isNotNull();
}
private static String param(String parameterName, String parameterValue) {
return "%s=%s".formatted(parameterName, URLEncoder.encode(parameterValue, StandardCharsets.UTF_8));
}
}

View File

@ -16,6 +16,7 @@
package org.springframework.security.oauth2.client.endpoint;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.Collections;
@ -37,6 +38,7 @@ import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ReactiveHttpInputMessage;
import org.springframework.security.oauth2.client.MockResponses;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
@ -46,6 +48,7 @@ import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
import org.springframework.security.oauth2.jose.TestJwks;
import org.springframework.security.oauth2.jose.TestKeys;
@ -105,14 +108,7 @@ public class WebClientReactiveRefreshTokenTokenResponseClientTests {
@Test
public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
Instant expiresAtBefore = Instant.now().plusSeconds(3600);
OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(
this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken);
@ -139,14 +135,7 @@ public class WebClientReactiveRefreshTokenTokenResponseClientTests {
@Test
public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistrationBuilder
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST)
.build();
@ -162,14 +151,7 @@ public class WebClientReactiveRefreshTokenTokenResponseClientTests {
@Test
public void getTokenResponseWhenAuthenticationClientSecretJwtThenFormParametersAreSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
// @formatter:off
ClientRegistration clientRegistration = this.clientRegistrationBuilder
@ -197,14 +179,7 @@ public class WebClientReactiveRefreshTokenTokenResponseClientTests {
@Test
public void getTokenResponseWhenAuthenticationPrivateKeyJwtThenFormParametersAreSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
// @formatter:off
ClientRegistration clientRegistration = this.clientRegistrationBuilder
@ -235,14 +210,7 @@ public class WebClientReactiveRefreshTokenTokenResponseClientTests {
@Test
public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"not-bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("invalid-token-type-response.json"));
OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(
this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken);
assertThatExceptionOfType(OAuth2AuthorizationException.class)
@ -254,15 +222,7 @@ public class WebClientReactiveRefreshTokenTokenResponseClientTests {
@Test
public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response-read.json"));
OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(
this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken,
Collections.singleton("read"));
@ -277,12 +237,7 @@ public class WebClientReactiveRefreshTokenTokenResponseClientTests {
@Test
public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() {
// @formatter:off
String accessTokenErrorResponse = "{\n"
+ " \"error\": \"unauthorized_client\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400));
this.server.enqueue(MockResponses.json("unauthorized-client-response.json").setResponseCode(400));
OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(
this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken);
assertThatExceptionOfType(OAuth2AuthorizationException.class)
@ -303,14 +258,6 @@ public class WebClientReactiveRefreshTokenTokenResponseClientTests {
.withMessageContaining("Empty OAuth 2.0 Access Token Response");
}
private MockResponse jsonResponse(String json) {
// @formatter:off
return new MockResponse()
.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.setBody(json);
// @formatter:on
}
// gh-10130
@Test
public void setHeadersConverterWhenNullThenThrowIllegalArgumentException() {
@ -327,23 +274,15 @@ public class WebClientReactiveRefreshTokenTokenResponseClientTests {
// gh-10130
@Test
public void convertWhenHeadersConverterAddedThenCalled() throws Exception {
public void getTokenResponseWhenHeadersConverterAddedThenCalled() throws Exception {
OAuth2RefreshTokenGrantRequest request = new OAuth2RefreshTokenGrantRequest(
this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken);
Converter<OAuth2RefreshTokenGrantRequest, HttpHeaders> addedHeadersConverter = mock(Converter.class);
Converter<OAuth2RefreshTokenGrantRequest, HttpHeaders> addedHeadersConverter = mock();
HttpHeaders headers = new HttpHeaders();
headers.put("custom-header-name", Collections.singletonList("custom-header-value"));
given(addedHeadersConverter.convert(request)).willReturn(headers);
this.tokenResponseClient.addHeadersConverter(addedHeadersConverter);
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
this.tokenResponseClient.getTokenResponse(request).block();
verify(addedHeadersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
@ -354,24 +293,16 @@ public class WebClientReactiveRefreshTokenTokenResponseClientTests {
// gh-10130
@Test
public void convertWhenHeadersConverterSetThenCalled() throws Exception {
public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception {
OAuth2RefreshTokenGrantRequest request = new OAuth2RefreshTokenGrantRequest(
this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken);
ClientRegistration clientRegistration = request.getClientRegistration();
Converter<OAuth2RefreshTokenGrantRequest, HttpHeaders> headersConverter1 = mock(Converter.class);
Converter<OAuth2RefreshTokenGrantRequest, HttpHeaders> headersConverter1 = mock();
HttpHeaders headers = new HttpHeaders();
headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret());
given(headersConverter1.convert(request)).willReturn(headers);
this.tokenResponseClient.setHeadersConverter(headersConverter1);
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
this.tokenResponseClient.getTokenResponse(request).block();
verify(headersConverter1).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
@ -392,24 +323,15 @@ public class WebClientReactiveRefreshTokenTokenResponseClientTests {
}
@Test
public void convertWhenParametersConverterAddedThenCalled() throws Exception {
public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exception {
OAuth2RefreshTokenGrantRequest request = new OAuth2RefreshTokenGrantRequest(
this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken);
Converter<OAuth2RefreshTokenGrantRequest, MultiValueMap<String, String>> addedParametersConverter = mock(
Converter.class);
Converter<OAuth2RefreshTokenGrantRequest, MultiValueMap<String, String>> addedParametersConverter = mock();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add("custom-parameter-name", "custom-parameter-value");
given(addedParametersConverter.convert(request)).willReturn(parameters);
this.tokenResponseClient.addParametersConverter(addedParametersConverter);
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
+ "}";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
this.tokenResponseClient.getTokenResponse(request).block();
verify(addedParametersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
@ -418,39 +340,51 @@ public class WebClientReactiveRefreshTokenTokenResponseClientTests {
}
@Test
public void convertWhenParametersConverterSetThenCalled() throws Exception {
public void getTokenResponseWhenParametersConverterSetThenCalled() throws Exception {
OAuth2RefreshTokenGrantRequest request = new OAuth2RefreshTokenGrantRequest(
this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken);
Converter<OAuth2RefreshTokenGrantRequest, MultiValueMap<String, String>> parametersConverter = mock(
Converter.class);
Converter<OAuth2RefreshTokenGrantRequest, MultiValueMap<String, String>> parametersConverter = mock();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add("custom-parameter-name", "custom-parameter-value");
given(parametersConverter.convert(request)).willReturn(parameters);
this.tokenResponseClient.setParametersConverter(parametersConverter);
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
+ "}";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
this.tokenResponseClient.getTokenResponse(request).block();
verify(parametersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
assertThat(actualRequest.getBody().readUtf8()).contains("custom-parameter-name=custom-parameter-value");
}
@Test
public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultParameters() throws Exception {
this.clientRegistrationBuilder.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST);
OAuth2RefreshTokenGrantRequest request = new OAuth2RefreshTokenGrantRequest(
this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken);
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.set(OAuth2ParameterNames.GRANT_TYPE, "custom");
parameters.set(OAuth2ParameterNames.REFRESH_TOKEN, "custom-token");
parameters.set(OAuth2ParameterNames.SCOPE, "one two");
this.tokenResponseClient.setParametersConverter((grantRequest) -> parameters);
this.server.enqueue(MockResponses.json("access-token-response.json"));
this.tokenResponseClient.getTokenResponse(request).block();
String formParameters = this.server.takeRequest().getBody().readUtf8();
// @formatter:off
assertThat(formParameters).contains(
param(OAuth2ParameterNames.GRANT_TYPE, "custom"),
param(OAuth2ParameterNames.CLIENT_ID, "client-id"),
param(OAuth2ParameterNames.REFRESH_TOKEN, "custom-token"),
param(OAuth2ParameterNames.SCOPE, "one two")
);
// @formatter:on
}
// gh-10260
@Test
public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse() {
String accessTokenSuccessResponse = "{}";
WebClientReactiveRefreshTokenTokenResponseClient customClient = new WebClientReactiveRefreshTokenTokenResponseClient();
BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> extractor = mock(BodyExtractor.class);
BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> extractor = mock();
OAuth2AccessTokenResponse response = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(extractor.extract(any(), any())).willReturn(Mono.just(response));
@ -459,7 +393,7 @@ public class WebClientReactiveRefreshTokenTokenResponseClientTests {
OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(
this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken);
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
OAuth2AccessTokenResponse accessTokenResponse = customClient.getTokenResponse(refreshTokenGrantRequest).block();
assertThat(accessTokenResponse.getAccessToken()).isNotNull();
@ -489,4 +423,8 @@ public class WebClientReactiveRefreshTokenTokenResponseClientTests {
.isThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest).block());
}
private static String param(String parameterName, String parameterValue) {
return "%s=%s".formatted(parameterName, URLEncoder.encode(parameterValue, StandardCharsets.UTF_8));
}
}

View File

@ -35,6 +35,7 @@ import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ReactiveHttpInputMessage;
import org.springframework.security.oauth2.client.MockResponses;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
@ -89,14 +90,12 @@ public class WebClientReactiveTokenExchangeTokenResponseClientTests {
this.server = new MockWebServer();
this.server.start();
String tokenUri = this.server.url("/oauth2/token").toString();
// @formatter:off
this.clientRegistration = TestClientRegistrations.clientCredentials()
.clientId("client-1")
.clientSecret("secret")
.authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE)
.tokenUri(tokenUri)
.scope("read", "write");
// @formatter:on
.clientId("client-1")
.clientSecret("secret")
.authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE)
.tokenUri(tokenUri)
.scope("read", "write");
this.subjectToken = TestOAuth2AccessTokens.scopes("read", "write");
this.actorToken = null;
}
@ -171,15 +170,7 @@ public class WebClientReactiveTokenExchangeTokenResponseClientTests {
@Test
public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read write\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response-read-write.json"));
Instant expiresAtBefore = Instant.now().plusSeconds(3600);
TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(),
this.subjectToken, this.actorToken);
@ -210,15 +201,7 @@ public class WebClientReactiveTokenExchangeTokenResponseClientTests {
@Test
public void getTokenResponseWhenSubjectTokenIsJwtThenSubjectTokenTypeIsJwt() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read write\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response-read-write.json"));
Instant expiresAtBefore = Instant.now().plusSeconds(3600);
this.subjectToken = TestJwts.jwt().build();
TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(),
@ -250,15 +233,7 @@ public class WebClientReactiveTokenExchangeTokenResponseClientTests {
@Test
public void getTokenResponseWhenActorTokenIsNotNullThenActorParametersAreSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read write\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response-read-write.json"));
Instant expiresAtBefore = Instant.now().plusSeconds(3600);
this.actorToken = TestOAuth2AccessTokens.noScopes();
TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(),
@ -292,15 +267,7 @@ public class WebClientReactiveTokenExchangeTokenResponseClientTests {
@Test
public void getTokenResponseWhenActorTokenIsJwtThenActorTokenTypeIsJwt() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read write\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response-read-write.json"));
Instant expiresAtBefore = Instant.now().plusSeconds(3600);
this.actorToken = TestJwts.jwt().build();
TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(),
@ -334,14 +301,7 @@ public class WebClientReactiveTokenExchangeTokenResponseClientTests {
@Test
public void getTokenResponseWhenAuthenticationClientSecretBasicThenAuthorizationHeaderIsSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(),
this.subjectToken, this.actorToken);
this.tokenResponseClient.getTokenResponse(grantRequest).block();
@ -351,14 +311,7 @@ public class WebClientReactiveTokenExchangeTokenResponseClientTests {
@Test
public void getTokenResponseWhenAuthenticationClientSecretPostThenFormParametersAreSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST)
.build();
@ -367,19 +320,17 @@ public class WebClientReactiveTokenExchangeTokenResponseClientTests {
this.tokenResponseClient.getTokenResponse(grantRequest).block();
RecordedRequest recordedRequest = this.server.takeRequest();
String formParameters = recordedRequest.getBody().readUtf8();
assertThat(formParameters).contains("client_id=client-1", "client_secret=secret");
// @formatter:off
assertThat(formParameters).contains(
param(OAuth2ParameterNames.CLIENT_ID, "client-1"),
param(OAuth2ParameterNames.CLIENT_SECRET, "secret")
);
// @formatter:on
}
@Test
public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"not-bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("invalid-token-type-response.json"));
TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(),
this.subjectToken, this.actorToken);
// @formatter:off
@ -393,15 +344,7 @@ public class WebClientReactiveTokenExchangeTokenResponseClientTests {
@Test
public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response-read.json"));
TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(),
this.subjectToken, this.actorToken);
OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest).block();
@ -411,14 +354,7 @@ public class WebClientReactiveTokenExchangeTokenResponseClientTests {
@Test
public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenAccessTokenHasNoScope() {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(),
this.subjectToken, this.actorToken);
OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest).block();
@ -441,8 +377,7 @@ public class WebClientReactiveTokenExchangeTokenResponseClientTests {
@Test
public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() {
String accessTokenErrorResponse = "{\"error\": \"server_error\", \"error_description\": \"A server error occurred\"}";
this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(500));
this.server.enqueue(MockResponses.json("server-error-response.json").setResponseCode(500));
TokenExchangeGrantRequest request = new TokenExchangeGrantRequest(this.clientRegistration.build(),
this.subjectToken, this.actorToken);
// @formatter:off
@ -455,8 +390,7 @@ public class WebClientReactiveTokenExchangeTokenResponseClientTests {
@Test
public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() {
String accessTokenErrorResponse = "{\"error\": \"invalid_grant\", \"error_description\": \"Invalid grant\"}";
this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400));
this.server.enqueue(MockResponses.json("invalid-grant-response.json").setResponseCode(400));
TokenExchangeGrantRequest request = new TokenExchangeGrantRequest(this.clientRegistration.build(),
this.subjectToken, this.actorToken);
// @formatter:off
@ -497,17 +431,10 @@ public class WebClientReactiveTokenExchangeTokenResponseClientTests {
@Test
public void getTokenResponseWhenHeadersConverterAddedThenCalled() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(),
this.subjectToken, this.actorToken);
Converter<TokenExchangeGrantRequest, HttpHeaders> headersConverter = mock(Converter.class);
Converter<TokenExchangeGrantRequest, HttpHeaders> headersConverter = mock();
HttpHeaders headers = new HttpHeaders();
headers.put("custom-header-name", Collections.singletonList("custom-header-value"));
given(headersConverter.convert(grantRequest)).willReturn(headers);
@ -521,17 +448,10 @@ public class WebClientReactiveTokenExchangeTokenResponseClientTests {
@Test
public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(),
this.subjectToken, this.actorToken);
Converter<TokenExchangeGrantRequest, HttpHeaders> headersConverter = mock(Converter.class);
Converter<TokenExchangeGrantRequest, HttpHeaders> headersConverter = mock();
HttpHeaders headers = new HttpHeaders();
headers.put("custom-header-name", Collections.singletonList("custom-header-value"));
given(headersConverter.convert(grantRequest)).willReturn(headers);
@ -545,17 +465,10 @@ public class WebClientReactiveTokenExchangeTokenResponseClientTests {
@Test
public void getTokenResponseWhenParametersConverterSetThenCalled() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(MockResponses.json("access-token-response.json"));
TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(),
this.subjectToken, this.actorToken);
Converter<TokenExchangeGrantRequest, MultiValueMap<String, String>> parametersConverter = mock(Converter.class);
Converter<TokenExchangeGrantRequest, MultiValueMap<String, String>> parametersConverter = mock();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add("custom-parameter-name", "custom-parameter-value");
given(parametersConverter.convert(grantRequest)).willReturn(parameters);
@ -568,18 +481,34 @@ public class WebClientReactiveTokenExchangeTokenResponseClientTests {
}
@Test
public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultParameters() throws Exception {
this.clientRegistration.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST);
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.set(OAuth2ParameterNames.GRANT_TYPE, "custom");
parameters.set(OAuth2ParameterNames.SCOPE, "one two");
parameters.set(OAuth2ParameterNames.SUBJECT_TOKEN, "custom-token");
this.tokenResponseClient.setParametersConverter((request) -> parameters);
this.server.enqueue(MockResponses.json("access-token-response.json"));
TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(),
this.subjectToken, this.actorToken);
Converter<TokenExchangeGrantRequest, MultiValueMap<String, String>> parametersConverter = mock(Converter.class);
this.tokenResponseClient.getTokenResponse(grantRequest).block();
String formParameters = this.server.takeRequest().getBody().readUtf8();
// @formatter:off
assertThat(formParameters).contains(
param(OAuth2ParameterNames.GRANT_TYPE, "custom"),
param(OAuth2ParameterNames.CLIENT_ID, "client-1"),
param(OAuth2ParameterNames.SCOPE, "one two"),
param(OAuth2ParameterNames.SUBJECT_TOKEN, "custom-token")
);
// @formatter:on
}
@Test
public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exception {
this.server.enqueue(MockResponses.json("access-token-response.json"));
TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(),
this.subjectToken, this.actorToken);
Converter<TokenExchangeGrantRequest, MultiValueMap<String, String>> parametersConverter = mock();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add("custom-parameter-name", "custom-parameter-value");
given(parametersConverter.convert(grantRequest)).willReturn(parameters);
@ -602,16 +531,8 @@ public class WebClientReactiveTokenExchangeTokenResponseClientTests {
@Test
public void getTokenResponseWhenBodyExtractorSetThenCalled() {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> bodyExtractor = mock(
BodyExtractor.class);
this.server.enqueue(MockResponses.json("access-token-response.json"));
BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> bodyExtractor = mock();
OAuth2AccessTokenResponse response = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(bodyExtractor.extract(any(ReactiveHttpInputMessage.class), any(BodyExtractor.Context.class)))
.willReturn(Mono.just(response));
@ -625,15 +546,8 @@ public class WebClientReactiveTokenExchangeTokenResponseClientTests {
@Test
public void getTokenResponseWhenWebClientSetThenCalled() {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
WebClient customClient = mock(WebClient.class);
this.server.enqueue(MockResponses.json("access-token-response.json"));
WebClient customClient = mock();
given(customClient.post()).willReturn(WebClient.builder().build().post());
this.tokenResponseClient.setWebClient(customClient);
ClientRegistration clientRegistration = this.clientRegistration.build();
@ -643,10 +557,6 @@ public class WebClientReactiveTokenExchangeTokenResponseClientTests {
verify(customClient).post();
}
private MockResponse jsonResponse(String json) {
return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json);
}
private static String param(String parameterName, String parameterValue) {
return "%s=%s".formatted(parameterName, URLEncoder.encode(parameterValue, StandardCharsets.UTF_8));
}

View File

@ -0,0 +1,6 @@
{
"access_token": "access-token-1234",
"token_type": "Bearer",
"expires_in": 3600,
"scope": "create"
}

View File

@ -0,0 +1,9 @@
{
"access_token": "access-token-1234",
"token_type": "Bearer",
"expires_in": 3600,
"scope": "openid profile",
"refresh_token": "refresh-token-1234",
"custom_parameter_1": "custom-value-1",
"custom_parameter_2": "custom-value-2"
}

View File

@ -0,0 +1,6 @@
{
"access_token": "access-token-1234",
"token_type": "Bearer",
"expires_in": 3600,
"scope": "openid profile"
}

View File

@ -0,0 +1,6 @@
{
"access_token": "access-token-1234",
"token_type": "Bearer",
"expires_in": 3600,
"scope": "read write"
}

View File

@ -0,0 +1,6 @@
{
"access_token": "access-token-1234",
"token_type": "Bearer",
"expires_in": 3600,
"scope": "read"
}

View File

@ -1,5 +1,5 @@
{
"access_token": "token",
"access_token": "access-token-1234",
"token_type": "Bearer",
"expires_in": 3600
}

View File

@ -0,0 +1,4 @@
{
"error": "invalid_grant",
"error_description": "Invalid grant"
}

View File

@ -0,0 +1,5 @@
{
"access_token": "access-token-1234",
"token_type": "not-bearer",
"expires_in": 3600
}

View File

@ -0,0 +1,4 @@
{
"error": "server_error",
"error_description": "A server error occurred"
}

View File

@ -0,0 +1,4 @@
{
"error": "unauthorized_client",
"error_description": "Unauthorized client"
}