diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractRestClientOAuth2AccessTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractRestClientOAuth2AccessTokenResponseClient.java new file mode 100644 index 0000000000..94c1942520 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractRestClientOAuth2AccessTokenResponseClient.java @@ -0,0 +1,249 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.HttpHeaders; +import org.springframework.http.converter.FormHttpMessageConverter; +import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter; +import org.springframework.util.Assert; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.client.RestClient; +import org.springframework.web.client.RestClient.RequestHeadersSpec; +import org.springframework.web.client.RestClientException; + +/** + * Abstract base class for {@link RestClient}-based implementations of + * {@link OAuth2AccessTokenResponseClient} that communicate to the Authorization Server's + * Token Endpoint. + *

+ * Submits a form request body specific to the type of grant request and accepts a JSON + * response body containing an OAuth 2.0 Access Token Response or OAuth 2.0 Error + * Response. + * + * @param type of grant request + * @author Steve Riesenberg + * @since 6.4 + * @see RFC-6749 Token + * Endpoint + * @see RestClientAuthorizationCodeTokenResponseClient + * @see RestClientClientCredentialsTokenResponseClient + * @see RestClientRefreshTokenTokenResponseClient + * @see RestClientJwtBearerTokenResponseClient + * @see RestClientTokenExchangeTokenResponseClient + * @see DefaultOAuth2TokenRequestHeadersConverter + */ +public abstract class AbstractRestClientOAuth2AccessTokenResponseClient + implements OAuth2AccessTokenResponseClient { + + private static final String INVALID_TOKEN_RESPONSE_ERROR_CODE = "invalid_token_response"; + + // @formatter:off + private RestClient restClient = RestClient.builder() + .messageConverters((messageConverters) -> { + messageConverters.clear(); + messageConverters.add(new FormHttpMessageConverter()); + messageConverters.add(new OAuth2AccessTokenResponseHttpMessageConverter()); + }) + .defaultStatusHandler(new OAuth2ErrorResponseErrorHandler()) + .build(); + // @formatter:on + + private Converter> requestEntityConverter = this::validatingPopulateRequest; + + private Converter headersConverter = new DefaultOAuth2TokenRequestHeadersConverter<>(); + + private Converter> parametersConverter = this::createParameters; + + AbstractRestClientOAuth2AccessTokenResponseClient() { + } + + @Override + public OAuth2AccessTokenResponse getTokenResponse(T grantRequest) { + Assert.notNull(grantRequest, "grantRequest cannot be null"); + try { + // @formatter:off + OAuth2AccessTokenResponse accessTokenResponse = this.requestEntityConverter.convert(grantRequest) + .retrieve() + .body(OAuth2AccessTokenResponse.class); + // @formatter:on + if (accessTokenResponse == null) { + OAuth2Error error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, + "Empty OAuth 2.0 Access Token Response", null); + throw new OAuth2AuthorizationException(error); + } + return accessTokenResponse; + } + catch (RestClientException ex) { + OAuth2Error error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, + "An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + + ex.getMessage(), + null); + throw new OAuth2AuthorizationException(error, ex); + } + } + + private RequestHeadersSpec validatingPopulateRequest(T grantRequest) { + validateClientAuthenticationMethod(grantRequest); + return populateRequest(grantRequest); + } + + private void validateClientAuthenticationMethod(T grantRequest) { + ClientRegistration clientRegistration = grantRequest.getClientRegistration(); + ClientAuthenticationMethod clientAuthenticationMethod = clientRegistration.getClientAuthenticationMethod(); + boolean supportedClientAuthenticationMethod = clientAuthenticationMethod.equals(ClientAuthenticationMethod.NONE) + || clientAuthenticationMethod.equals(ClientAuthenticationMethod.CLIENT_SECRET_BASIC) + || clientAuthenticationMethod.equals(ClientAuthenticationMethod.CLIENT_SECRET_POST); + if (!supportedClientAuthenticationMethod) { + throw new IllegalArgumentException(String.format( + "This class supports `client_secret_basic`, `client_secret_post`, and `none` by default. Client [%s] is using [%s] instead. Please use a supported client authentication method, or use `set/addParametersConverter` or `set/addHeadersConverter` to supply an instance that supports [%s].", + clientRegistration.getRegistrationId(), clientAuthenticationMethod, clientAuthenticationMethod)); + } + } + + private RequestHeadersSpec populateRequest(T grantRequest) { + return this.restClient.post() + .uri(grantRequest.getClientRegistration().getProviderDetails().getTokenUri()) + .headers((headers) -> { + HttpHeaders headersToAdd = this.headersConverter.convert(grantRequest); + if (headersToAdd != null) { + headers.addAll(headersToAdd); + } + }) + .body(this.parametersConverter.convert(grantRequest)); + } + + /** + * Returns a {@link MultiValueMap} of the parameters used in the OAuth 2.0 Access + * Token Request body. + * @param grantRequest the authorization grant request + * @return a {@link MultiValueMap} of the parameters used in the OAuth 2.0 Access + * Token Request body + */ + MultiValueMap createParameters(T grantRequest) { + ClientRegistration clientRegistration = grantRequest.getClientRegistration(); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.set(OAuth2ParameterNames.GRANT_TYPE, grantRequest.getGrantType().getValue()); + if (!ClientAuthenticationMethod.CLIENT_SECRET_BASIC + .equals(clientRegistration.getClientAuthenticationMethod())) { + parameters.set(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId()); + } + if (ClientAuthenticationMethod.CLIENT_SECRET_POST.equals(clientRegistration.getClientAuthenticationMethod())) { + parameters.set(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret()); + } + return parameters; + } + + /** + * Sets the {@link RestClient} used when requesting the OAuth 2.0 Access Token + * Response. + * @param restClient the {@link RestClient} used when requesting the Access Token + * Response + */ + public final void setRestClient(RestClient restClient) { + Assert.notNull(restClient, "restClient cannot be null"); + this.restClient = restClient; + } + + /** + * Sets the {@link Converter} used for converting the + * {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link HttpHeaders} + * used in the OAuth 2.0 Access Token Request headers. + * @param headersConverter the {@link Converter} used for converting the + * {@link AbstractOAuth2AuthorizationGrantRequest} to {@link HttpHeaders} + */ + public final void setHeadersConverter(Converter headersConverter) { + Assert.notNull(headersConverter, "headersConverter cannot be null"); + this.headersConverter = headersConverter; + this.requestEntityConverter = this::populateRequest; + } + + /** + * Add (compose) the provided {@code headersConverter} to the current + * {@link Converter} used for converting the + * {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link HttpHeaders} + * used in the OAuth 2.0 Access Token Request headers. + * @param headersConverter the {@link Converter} to add (compose) to the current + * {@link Converter} used for converting the + * {@link AbstractOAuth2AuthorizationGrantRequest} to a {@link HttpHeaders} + */ + public final void addHeadersConverter(Converter headersConverter) { + Assert.notNull(headersConverter, "headersConverter cannot be null"); + Converter currentHeadersConverter = this.headersConverter; + this.headersConverter = (authorizationGrantRequest) -> { + // Append headers using a Composite Converter + HttpHeaders headers = currentHeadersConverter.convert(authorizationGrantRequest); + if (headers == null) { + headers = new HttpHeaders(); + } + HttpHeaders headersToAdd = headersConverter.convert(authorizationGrantRequest); + if (headersToAdd != null) { + headers.addAll(headersToAdd); + } + return headers; + }; + this.requestEntityConverter = this::populateRequest; + } + + /** + * Sets the {@link Converter} used for converting the + * {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link MultiValueMap} + * used in the OAuth 2.0 Access Token Request body. + * @param parametersConverter the {@link Converter} used for converting the + * {@link AbstractOAuth2AuthorizationGrantRequest} to {@link MultiValueMap} + */ + public final void setParametersConverter(Converter> parametersConverter) { + Assert.notNull(parametersConverter, "parametersConverter cannot be null"); + this.parametersConverter = parametersConverter; + this.requestEntityConverter = this::populateRequest; + } + + /** + * Add (compose) the provided {@code parametersConverter} to the current + * {@link Converter} used for converting the + * {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link MultiValueMap} + * used in the OAuth 2.0 Access Token Request body. + * @param parametersConverter the {@link Converter} to add (compose) to the current + * {@link Converter} used for converting the + * {@link AbstractOAuth2AuthorizationGrantRequest} to a {@link MultiValueMap} + */ + public final void addParametersConverter(Converter> parametersConverter) { + Assert.notNull(parametersConverter, "parametersConverter cannot be null"); + Converter> currentParametersConverter = this.parametersConverter; + this.parametersConverter = (authorizationGrantRequest) -> { + MultiValueMap parameters = currentParametersConverter.convert(authorizationGrantRequest); + if (parameters == null) { + parameters = new LinkedMultiValueMap<>(); + } + MultiValueMap parametersToAdd = parametersConverter.convert(authorizationGrantRequest); + if (parametersToAdd != null) { + parameters.addAll(parametersToAdd); + } + return parameters; + }; + this.requestEntityConverter = this::populateRequest; + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/RestClientAuthorizationCodeTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/RestClientAuthorizationCodeTokenResponseClient.java new file mode 100644 index 0000000000..a63d997a9a --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/RestClientAuthorizationCodeTokenResponseClient.java @@ -0,0 +1,63 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; +import org.springframework.util.MultiValueMap; + +/** + * An implementation of {@link OAuth2AccessTokenResponseClient} that "exchanges" + * an authorization code for an access token at the Authorization Server's Token Endpoint. + * + * @author Steve Riesenberg + * @since 6.4 + * @see OAuth2AccessTokenResponseClient + * @see OAuth2AuthorizationCodeGrantRequest + * @see OAuth2AccessTokenResponse + * @see Section 4.1.3 Access Token Request + * (Authorization Code Grant) + * @see Section 4.1.4 Access Token Response + * (Authorization Code Grant) + * @see Section + * 4.2 Client Creates the Code Challenge + */ +public final class RestClientAuthorizationCodeTokenResponseClient + extends AbstractRestClientOAuth2AccessTokenResponseClient { + + @Override + MultiValueMap createParameters(OAuth2AuthorizationCodeGrantRequest grantRequest) { + OAuth2AuthorizationExchange authorizationExchange = grantRequest.getAuthorizationExchange(); + MultiValueMap parameters = super.createParameters(grantRequest); + parameters.set(OAuth2ParameterNames.CODE, authorizationExchange.getAuthorizationResponse().getCode()); + String redirectUri = authorizationExchange.getAuthorizationRequest().getRedirectUri(); + if (redirectUri != null) { + parameters.set(OAuth2ParameterNames.REDIRECT_URI, redirectUri); + } + String codeVerifier = authorizationExchange.getAuthorizationRequest() + .getAttribute(PkceParameterNames.CODE_VERIFIER); + if (codeVerifier != null) { + parameters.set(PkceParameterNames.CODE_VERIFIER, codeVerifier); + } + return parameters; + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/RestClientClientCredentialsTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/RestClientClientCredentialsTokenResponseClient.java new file mode 100644 index 0000000000..7aa896e913 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/RestClientClientCredentialsTokenResponseClient.java @@ -0,0 +1,56 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.util.CollectionUtils; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; + +/** + * An implementation of {@link OAuth2AccessTokenResponseClient} that "exchanges" + * client credentials for an access token at the Authorization Server's Token Endpoint. + * + * @author Steve Riesenberg + * @since 6.4 + * @see OAuth2AccessTokenResponseClient + * @see OAuth2ClientCredentialsGrantRequest + * @see OAuth2AccessTokenResponse + * @see Section 4.1.3 Access Token Request + * (Authorization Code Grant) + * @see Section 4.1.4 Access Token Response + * (Authorization Code Grant) + */ +public final class RestClientClientCredentialsTokenResponseClient + extends AbstractRestClientOAuth2AccessTokenResponseClient { + + @Override + MultiValueMap createParameters(OAuth2ClientCredentialsGrantRequest grantRequest) { + ClientRegistration clientRegistration = grantRequest.getClientRegistration(); + MultiValueMap parameters = super.createParameters(grantRequest); + if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) { + parameters.set(OAuth2ParameterNames.SCOPE, + StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " ")); + } + return parameters; + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/RestClientJwtBearerTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/RestClientJwtBearerTokenResponseClient.java new file mode 100644 index 0000000000..6510241067 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/RestClientJwtBearerTokenResponseClient.java @@ -0,0 +1,55 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.util.CollectionUtils; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; + +/** + * An implementation of {@link OAuth2AccessTokenResponseClient} that "exchanges" + * a JWT for an access token at the Authorization Server's Token Endpoint. + * + * @author Steve Riesenberg + * @since 6.4 + * @see OAuth2AccessTokenResponseClient + * @see JwtBearerGrantRequest + * @see OAuth2AccessTokenResponse + * @see Section + * 2.1 Using JWTs as Authorization Grants + * @see Section + * 4.1 Using Assertions as Authorization Grants + */ +public final class RestClientJwtBearerTokenResponseClient + extends AbstractRestClientOAuth2AccessTokenResponseClient { + + @Override + MultiValueMap createParameters(JwtBearerGrantRequest grantRequest) { + ClientRegistration clientRegistration = grantRequest.getClientRegistration(); + MultiValueMap parameters = super.createParameters(grantRequest); + if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) { + parameters.set(OAuth2ParameterNames.SCOPE, + StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " ")); + } + parameters.set(OAuth2ParameterNames.ASSERTION, grantRequest.getJwt().getTokenValue()); + return parameters; + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/RestClientRefreshTokenTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/RestClientRefreshTokenTokenResponseClient.java new file mode 100644 index 0000000000..02519ca8aa --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/RestClientRefreshTokenTokenResponseClient.java @@ -0,0 +1,75 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.util.CollectionUtils; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; + +/** + * An implementation of {@link OAuth2AccessTokenResponseClient} that "exchanges" + * a refresh token for an access token at the Authorization Server's Token Endpoint. + * + * @author Steve Riesenberg + * @since 6.4 + * @see OAuth2AccessTokenResponseClient + * @see OAuth2RefreshTokenGrantRequest + * @see OAuth2AccessTokenResponse + * @see Section 6 + * Refreshing an Access Token + */ +public final class RestClientRefreshTokenTokenResponseClient + extends AbstractRestClientOAuth2AccessTokenResponseClient { + + @Override + public OAuth2AccessTokenResponse getTokenResponse(OAuth2RefreshTokenGrantRequest grantRequest) { + OAuth2AccessTokenResponse accessTokenResponse = super.getTokenResponse(grantRequest); + return populateTokenResponse(grantRequest, accessTokenResponse); + } + + @Override + MultiValueMap createParameters(OAuth2RefreshTokenGrantRequest grantRequest) { + MultiValueMap parameters = super.createParameters(grantRequest); + if (!CollectionUtils.isEmpty(grantRequest.getScopes())) { + parameters.set(OAuth2ParameterNames.SCOPE, + StringUtils.collectionToDelimitedString(grantRequest.getScopes(), " ")); + } + parameters.set(OAuth2ParameterNames.REFRESH_TOKEN, grantRequest.getRefreshToken().getTokenValue()); + return parameters; + } + + private OAuth2AccessTokenResponse populateTokenResponse(OAuth2RefreshTokenGrantRequest grantRequest, + OAuth2AccessTokenResponse accessTokenResponse) { + if (!CollectionUtils.isEmpty(accessTokenResponse.getAccessToken().getScopes()) + && accessTokenResponse.getRefreshToken() != null) { + return accessTokenResponse; + } + OAuth2AccessTokenResponse.Builder tokenResponseBuilder = OAuth2AccessTokenResponse + .withResponse(accessTokenResponse); + if (CollectionUtils.isEmpty(accessTokenResponse.getAccessToken().getScopes())) { + tokenResponseBuilder.scopes(grantRequest.getAccessToken().getScopes()); + } + if (accessTokenResponse.getRefreshToken() == null) { + // Reuse existing refresh token + tokenResponseBuilder.refreshToken(grantRequest.getRefreshToken().getTokenValue()); + } + return tokenResponseBuilder.build(); + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/RestClientTokenExchangeTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/RestClientTokenExchangeTokenResponseClient.java new file mode 100644 index 0000000000..e0e6544ad9 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/RestClientTokenExchangeTokenResponseClient.java @@ -0,0 +1,74 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.util.CollectionUtils; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; + +/** + * An implementation of {@link OAuth2AccessTokenResponseClient} that "exchanges" + * a subject token (and optionally an actor token) for an access token at the + * Authorization Server's Token Endpoint. + * + * @author Steve Riesenberg + * @since 6.4 + * @see OAuth2AccessTokenResponseClient + * @see TokenExchangeGrantRequest + * @see OAuth2AccessTokenResponse + * @see Section + * 2.1 Request + * @see Section + * 2.2 Response + */ +public final class RestClientTokenExchangeTokenResponseClient + extends AbstractRestClientOAuth2AccessTokenResponseClient { + + private static final String ACCESS_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:access_token"; + + private static final String JWT_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:jwt"; + + @Override + MultiValueMap createParameters(TokenExchangeGrantRequest grantRequest) { + ClientRegistration clientRegistration = grantRequest.getClientRegistration(); + MultiValueMap parameters = super.createParameters(grantRequest); + if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) { + parameters.set(OAuth2ParameterNames.SCOPE, + StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " ")); + } + parameters.set(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE); + OAuth2Token subjectToken = grantRequest.getSubjectToken(); + parameters.set(OAuth2ParameterNames.SUBJECT_TOKEN, subjectToken.getTokenValue()); + parameters.set(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE, tokenType(subjectToken)); + OAuth2Token actorToken = grantRequest.getActorToken(); + if (actorToken != null) { + parameters.set(OAuth2ParameterNames.ACTOR_TOKEN, actorToken.getTokenValue()); + parameters.set(OAuth2ParameterNames.ACTOR_TOKEN_TYPE, tokenType(actorToken)); + } + return parameters; + } + + private static String tokenType(OAuth2Token token) { + return (token instanceof Jwt) ? JWT_TOKEN_TYPE_VALUE : ACCESS_TOKEN_TYPE_VALUE; + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientAuthorizationCodeTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientAuthorizationCodeTokenResponseClientTests.java new file mode 100644 index 0000000000..f365a93291 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientAuthorizationCodeTokenResponseClientTests.java @@ -0,0 +1,511 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import java.io.IOException; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.util.Collections; + +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.http.converter.FormHttpMessageConverter; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.client.RestClient; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link RestClientAuthorizationCodeTokenResponseClient}. + * + * @author Steve Riesenberg + */ +public class RestClientAuthorizationCodeTokenResponseClientTests { + + private RestClientAuthorizationCodeTokenResponseClient tokenResponseClient; + + private MockWebServer server; + + private ClientRegistration.Builder clientRegistration; + + private OAuth2AuthorizationExchange authorizationExchange; + + @BeforeEach + public void setUp() throws IOException { + this.tokenResponseClient = new RestClientAuthorizationCodeTokenResponseClient(); + this.server = new MockWebServer(); + this.server.start(); + String tokenUri = this.server.url("/oauth2/token").toString(); + // @formatter:off + this.clientRegistration = TestClientRegistrations.clientRegistration() + .clientId("client-1") + .clientSecret("secret") + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) + .tokenUri(tokenUri) + .scope("read", "write"); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() + .clientId("client-1") + .state("state") + .authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri()) + .redirectUri(clientRegistration.getRedirectUri()) + .scopes(clientRegistration.getScopes()) + .build(); + OAuth2AuthorizationResponse authorizationResponse = OAuth2AuthorizationResponse.success("code") + .state("state") + .redirectUri(clientRegistration.getRedirectUri()) + .build(); + // @formatter:on + this.authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse); + } + + @AfterEach + public void cleanUp() throws IOException { + this.server.shutdown(); + } + + @Test + public void setRestClientWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.setRestClient(null)) + .withMessage("restClient cannot be null"); + // @formatter:on + } + + @Test + public void setHeadersConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.setHeadersConverter(null)) + .withMessage("headersConverter cannot be null"); + // @formatter:on + } + + @Test + public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.addHeadersConverter(null)) + .withMessage("headersConverter cannot be null"); + // @formatter:on + } + + @Test + public void setParametersConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.setParametersConverter(null)) + .withMessage("parametersConverter cannot be null"); + // @formatter:on + } + + @Test + public void addParametersConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.addParametersConverter(null)) + .withMessage("parametersConverter cannot be null"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenGrantRequestIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(null)) + .withMessage("grantRequest cannot be null"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + Instant expiresAtBefore = Instant.now().plusSeconds(3600); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration, + this.authorizationExchange); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest); + assertThat(accessTokenResponse).isNotNull(); + Instant expiresAtAfter = Instant.now().plusSeconds(3600); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); + assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)) + .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE); + String formParameters = recordedRequest.getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue()), + param(OAuth2ParameterNames.CODE, "code") + ); + // @formatter:on + assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); + assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); + assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactlyInAnyOrder("read", "write"); + assertThat(accessTokenResponse.getRefreshToken()).isNull(); + } + + @Test + public void getTokenResponseWhenAuthenticationClientSecretBasicThenAuthorizationHeaderIsSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration, + this.authorizationExchange); + this.tokenResponseClient.getTokenResponse(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); + } + + @Test + public void getTokenResponseWhenAuthenticationClientSecretPostThenFormParametersAreSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST) + .build(); + OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration, + this.authorizationExchange); + this.tokenResponseClient.getTokenResponse(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters).contains("client_id=client-1", "client_secret=secret"); + } + + @Test + public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"not-bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration, + this.authorizationExchange); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(grantRequest)) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo("invalid_token_response")) + .withMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response") + .havingRootCause().withMessage("tokenType cannot be null"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration, + this.authorizationExchange); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest); + assertThat(accessTokenResponse).isNotNull(); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read"); + } + + @Test + public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenAccessTokenHasNoScope() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration, + this.authorizationExchange); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest); + assertThat(accessTokenResponse).isNotNull(); + assertThat(accessTokenResponse.getAccessToken().getScopes()).isEmpty(); + } + + @Test + public void getTokenResponseWhenInvalidResponseThenThrowOAuth2AuthorizationException() { + this.server.enqueue(new MockResponse().setResponseCode(301)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2AuthorizationCodeGrantRequest request = new OAuth2AuthorizationCodeGrantRequest(clientRegistration, + this.authorizationExchange); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(request)) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo("invalid_token_response")) + .withMessage("[invalid_token_response] Empty OAuth 2.0 Access Token Response"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() { + String accessTokenErrorResponse = "{\"error\": \"server_error\", \"error_description\": \"A server error occurred\"}"; + this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(500)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2AuthorizationCodeGrantRequest request = new OAuth2AuthorizationCodeGrantRequest(clientRegistration, + this.authorizationExchange); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(request)) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo("invalid_token_response")) + .withMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() { + String accessTokenErrorResponse = "{\"error\": \"invalid_grant\", \"error_description\": \"Invalid grant\"}"; + this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2AuthorizationCodeGrantRequest request = new OAuth2AuthorizationCodeGrantRequest(clientRegistration, + this.authorizationExchange); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(request)) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_GRANT)) + .withMessage("[invalid_grant] Invalid grant"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenCustomClientAuthenticationMethodThenIllegalArgument() { + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(new ClientAuthenticationMethod("basic")) + .build(); + OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration, + this.authorizationExchange); + // @formatter:off + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(grantRequest)) + .withMessageContaining("This class supports `client_secret_basic`, `client_secret_post`, and `none` by default."); + // @formatter:on + } + + @Test + public void getTokenResponseWhenUnsupportedClientAuthenticationMethodThenIllegalArgument() { + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT) + .build(); + OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration, + this.authorizationExchange); + // @formatter:off + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(grantRequest)) + .withMessageContaining("This class supports `client_secret_basic`, `client_secret_post`, and `none` by default."); + // @formatter:on + } + + @Test + public void getTokenResponseWhenHeadersConverterAddedThenCalled() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration, + this.authorizationExchange); + Converter headersConverter = mock(Converter.class); + HttpHeaders headers = new HttpHeaders(); + headers.put("custom-header-name", Collections.singletonList("custom-header-value")); + given(headersConverter.convert(grantRequest)).willReturn(headers); + this.tokenResponseClient.addHeadersConverter(headersConverter); + this.tokenResponseClient.getTokenResponse(grantRequest); + verify(headersConverter).convert(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); + assertThat(recordedRequest.getHeader("custom-header-name")).isEqualTo("custom-header-value"); + } + + @Test + public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration, + this.authorizationExchange); + Converter headersConverter = mock(Converter.class); + HttpHeaders headers = new HttpHeaders(); + headers.put("custom-header-name", Collections.singletonList("custom-header-value")); + given(headersConverter.convert(grantRequest)).willReturn(headers); + this.tokenResponseClient.setHeadersConverter(headersConverter); + this.tokenResponseClient.getTokenResponse(grantRequest); + verify(headersConverter).convert(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); + assertThat(recordedRequest.getHeader("custom-header-name")).isEqualTo("custom-header-value"); + } + + @Test + public void getTokenResponseWhenParametersConverterSetThenCalled() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration, + this.authorizationExchange); + Converter> parametersConverter = mock( + Converter.class); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.add("custom-parameter-name", "custom-parameter-value"); + given(parametersConverter.convert(grantRequest)).willReturn(parameters); + this.tokenResponseClient.setParametersConverter(parametersConverter); + this.tokenResponseClient.getTokenResponse(grantRequest); + verify(parametersConverter).convert(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters).contains("custom-parameter-name=custom-parameter-value"); + } + + @Test + public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration, + this.authorizationExchange); + Converter> parametersConverter = mock( + Converter.class); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.add("custom-parameter-name", "custom-parameter-value"); + given(parametersConverter.convert(grantRequest)).willReturn(parameters); + this.tokenResponseClient.addParametersConverter(parametersConverter); + this.tokenResponseClient.getTokenResponse(grantRequest); + verify(parametersConverter).convert(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + String formParameters = recordedRequest.getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue()), + param(OAuth2ParameterNames.CODE, "code"), + param("custom-parameter-name", "custom-parameter-value") + ); + // @formatter:on + } + + @Test + public void getTokenResponseWhenRestClientSetThenCalled() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + RestClient restClient = RestClient.builder().messageConverters((messageConverters) -> { + messageConverters.add(0, new FormHttpMessageConverter()); + messageConverters.add(1, new OAuth2AccessTokenResponseHttpMessageConverter()); + }).build(); + RestClient customClient = spy(restClient); + this.tokenResponseClient.setRestClient(customClient); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration, + this.authorizationExchange); + this.tokenResponseClient.getTokenResponse(grantRequest); + verify(customClient).post(); + } + + private static MockResponse jsonResponse(String json) { + return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json); + } + + private static String param(String parameterName, String parameterValue) { + return "%s=%s".formatted(parameterName, URLEncoder.encode(parameterValue, StandardCharsets.UTF_8)); + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientClientCredentialsTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientClientCredentialsTokenResponseClientTests.java new file mode 100644 index 0000000000..c97a02ca1b --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientClientCredentialsTokenResponseClientTests.java @@ -0,0 +1,518 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import java.io.IOException; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.util.Collections; +import java.util.Set; + +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.http.converter.FormHttpMessageConverter; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; +import org.springframework.web.client.RestClient; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link RestClientClientCredentialsTokenResponseClient}. + * + * @author Steve Riesenberg + */ +public class RestClientClientCredentialsTokenResponseClientTests { + + private RestClientClientCredentialsTokenResponseClient tokenResponseClient; + + private MockWebServer server; + + private ClientRegistration.Builder clientRegistration; + + @BeforeEach + public void setUp() throws IOException { + this.tokenResponseClient = new RestClientClientCredentialsTokenResponseClient(); + this.server = new MockWebServer(); + this.server.start(); + String tokenUri = this.server.url("/oauth2/token").toString(); + // @formatter:off + this.clientRegistration = TestClientRegistrations.clientCredentials() + .clientId("client-1") + .clientSecret("secret") + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) + .tokenUri(tokenUri) + .scope("read", "write"); + // @formatter:on + } + + @AfterEach + public void cleanUp() throws IOException { + this.server.shutdown(); + } + + @Test + public void setRestClientWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.setRestClient(null)) + .withMessage("restClient cannot be null"); + // @formatter:on + } + + @Test + public void setHeadersConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.setHeadersConverter(null)) + .withMessage("headersConverter cannot be null"); + // @formatter:on + } + + @Test + public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.addHeadersConverter(null)) + .withMessage("headersConverter cannot be null"); + // @formatter:on + } + + @Test + public void setParametersConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.setParametersConverter(null)) + .withMessage("parametersConverter cannot be null"); + // @formatter:on + } + + @Test + public void addParametersConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.addParametersConverter(null)) + .withMessage("parametersConverter cannot be null"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenGrantRequestIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(null)) + .withMessage("grantRequest cannot be null"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + Instant expiresAtBefore = Instant.now().plusSeconds(3600); + ClientRegistration clientRegistration = this.clientRegistration.build(); + Set scopes = clientRegistration.getScopes(); + OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest); + assertThat(accessTokenResponse).isNotNull(); + Instant expiresAtAfter = Instant.now().plusSeconds(3600); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); + assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)) + .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE); + String formParameters = recordedRequest.getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()), + param(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(scopes, " ")) + ); + // @formatter:on + assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); + assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); + assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactlyInAnyOrder("read", "write"); + assertThat(accessTokenResponse.getRefreshToken()).isNull(); + } + + @Test + public void getTokenResponseWhenAuthenticationClientSecretBasicThenAuthorizationHeaderIsSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); + this.tokenResponseClient.getTokenResponse(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); + } + + @Test + public void getTokenResponseWhenAuthenticationClientSecretPostThenFormParametersAreSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST) + .build(); + OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); + this.tokenResponseClient.getTokenResponse(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters).contains("client_id=client-1", "client_secret=secret"); + } + + @Test + public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"not-bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(grantRequest)) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo("invalid_token_response")) + .withMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response") + .havingRootCause().withMessage("tokenType cannot be null"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest); + assertThat(accessTokenResponse).isNotNull(); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read"); + } + + @Test + public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenAccessTokenHasNoScope() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest); + assertThat(accessTokenResponse).isNotNull(); + assertThat(accessTokenResponse.getAccessToken().getScopes()).isEmpty(); + } + + @Test + public void getTokenResponseWhenRequestDoesNotIncludeScopeThenAccessTokenHasNoScope() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + // @formatter:off + ClientRegistration clientRegistration = ClientRegistration.withRegistrationId("no-scope") + .clientId("client-1") + .clientSecret("secret") + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC) + .tokenUri(this.server.url("/oauth2/token").toString()) + .build(); + // @formatter:on + OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest); + assertThat(accessTokenResponse).isNotNull(); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); + assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)) + .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE); + String formParameters = recordedRequest.getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) + ); + // @formatter:on + assertThat(formParameters).doesNotContain(OAuth2ParameterNames.SCOPE); + assertThat(accessTokenResponse.getAccessToken().getScopes()).isEmpty(); + } + + @Test + public void getTokenResponseWhenInvalidResponseThenThrowOAuth2AuthorizationException() { + this.server.enqueue(new MockResponse().setResponseCode(301)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(clientRegistration); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(request)) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo("invalid_token_response")) + .withMessage("[invalid_token_response] Empty OAuth 2.0 Access Token Response"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() { + String accessTokenErrorResponse = "{\"error\": \"server_error\", \"error_description\": \"A server error occurred\"}"; + this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(500)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(clientRegistration); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(request)) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo("invalid_token_response")) + .withMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() { + String accessTokenErrorResponse = "{\"error\": \"invalid_grant\", \"error_description\": \"Invalid grant\"}"; + this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(clientRegistration); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(request)) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_GRANT)) + .withMessage("[invalid_grant] Invalid grant"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenCustomClientAuthenticationMethodThenIllegalArgument() { + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(new ClientAuthenticationMethod("basic")) + .build(); + OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); + // @formatter:off + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(grantRequest)) + .withMessageContaining("This class supports `client_secret_basic`, `client_secret_post`, and `none` by default."); + // @formatter:on + } + + @Test + public void getTokenResponseWhenUnsupportedClientAuthenticationMethodThenIllegalArgument() { + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT) + .build(); + OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); + // @formatter:off + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(grantRequest)) + .withMessageContaining("This class supports `client_secret_basic`, `client_secret_post`, and `none` by default."); + // @formatter:on + } + + @Test + public void getTokenResponseWhenHeadersConverterAddedThenCalled() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); + Converter headersConverter = mock(Converter.class); + HttpHeaders headers = new HttpHeaders(); + headers.put("custom-header-name", Collections.singletonList("custom-header-value")); + given(headersConverter.convert(grantRequest)).willReturn(headers); + this.tokenResponseClient.addHeadersConverter(headersConverter); + this.tokenResponseClient.getTokenResponse(grantRequest); + verify(headersConverter).convert(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); + assertThat(recordedRequest.getHeader("custom-header-name")).isEqualTo("custom-header-value"); + } + + @Test + public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); + Converter headersConverter = mock(Converter.class); + HttpHeaders headers = new HttpHeaders(); + headers.put("custom-header-name", Collections.singletonList("custom-header-value")); + given(headersConverter.convert(grantRequest)).willReturn(headers); + this.tokenResponseClient.setHeadersConverter(headersConverter); + this.tokenResponseClient.getTokenResponse(grantRequest); + verify(headersConverter).convert(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); + assertThat(recordedRequest.getHeader("custom-header-name")).isEqualTo("custom-header-value"); + } + + @Test + public void getTokenResponseWhenParametersConverterSetThenCalled() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); + Converter> parametersConverter = mock( + Converter.class); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.add("custom-parameter-name", "custom-parameter-value"); + given(parametersConverter.convert(grantRequest)).willReturn(parameters); + this.tokenResponseClient.setParametersConverter(parametersConverter); + this.tokenResponseClient.getTokenResponse(grantRequest); + verify(parametersConverter).convert(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters).contains("custom-parameter-name=custom-parameter-value"); + } + + @Test + public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + Set scopes = clientRegistration.getScopes(); + OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); + Converter> parametersConverter = mock( + Converter.class); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.add("custom-parameter-name", "custom-parameter-value"); + given(parametersConverter.convert(grantRequest)).willReturn(parameters); + this.tokenResponseClient.addParametersConverter(parametersConverter); + this.tokenResponseClient.getTokenResponse(grantRequest); + verify(parametersConverter).convert(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + String formParameters = recordedRequest.getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()), + param(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(scopes, " ")), + param("custom-parameter-name", "custom-parameter-value") + ); + // @formatter:on + } + + @Test + public void getTokenResponseWhenRestClientSetThenCalled() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + RestClient restClient = RestClient.builder().messageConverters((messageConverters) -> { + messageConverters.add(0, new FormHttpMessageConverter()); + messageConverters.add(1, new OAuth2AccessTokenResponseHttpMessageConverter()); + }).build(); + RestClient customClient = spy(restClient); + this.tokenResponseClient.setRestClient(customClient); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); + this.tokenResponseClient.getTokenResponse(grantRequest); + verify(customClient).post(); + } + + private static MockResponse jsonResponse(String json) { + return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json); + } + + private static String param(String parameterName, String parameterValue) { + return "%s=%s".formatted(parameterName, URLEncoder.encode(parameterValue, StandardCharsets.UTF_8)); + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientJwtBearerTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientJwtBearerTokenResponseClientTests.java new file mode 100644 index 0000000000..db8c822888 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientJwtBearerTokenResponseClientTests.java @@ -0,0 +1,480 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import java.io.IOException; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.util.Collections; +import java.util.Set; + +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.TestJwts; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; +import org.springframework.web.client.RestClient; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link RestClientJwtBearerTokenResponseClient}. + * + * @author Steve Riesenberg + */ +public class RestClientJwtBearerTokenResponseClientTests { + + private RestClientJwtBearerTokenResponseClient tokenResponseClient; + + private MockWebServer server; + + private ClientRegistration.Builder clientRegistration; + + private Jwt jwtAssertion; + + @BeforeEach + public void setUp() throws IOException { + this.tokenResponseClient = new RestClientJwtBearerTokenResponseClient(); + this.server = new MockWebServer(); + this.server.start(); + String tokenUri = this.server.url("/oauth2/token").toString(); + // @formatter:off + this.clientRegistration = TestClientRegistrations.clientCredentials() + .clientId("client-1") + .clientSecret("secret") + .authorizationGrantType(AuthorizationGrantType.JWT_BEARER) + .tokenUri(tokenUri) + .scope("read", "write"); + // @formatter:on + this.jwtAssertion = TestJwts.jwt().build(); + } + + @AfterEach + public void cleanUp() throws IOException { + this.server.shutdown(); + } + + @Test + public void setRestClientWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.setRestClient(null)) + .withMessage("restClient cannot be null"); + // @formatter:on + } + + @Test + public void setHeadersConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.setHeadersConverter(null)) + .withMessage("headersConverter cannot be null"); + // @formatter:on + } + + @Test + public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.addHeadersConverter(null)) + .withMessage("headersConverter cannot be null"); + // @formatter:on + } + + @Test + public void setParametersConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.setParametersConverter(null)) + .withMessage("parametersConverter cannot be null"); + // @formatter:on + } + + @Test + public void addParametersConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.addParametersConverter(null)) + .withMessage("parametersConverter cannot be null"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenGrantRequestIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(null)) + .withMessage("grantRequest cannot be null"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + Instant expiresAtBefore = Instant.now().plusSeconds(3600); + ClientRegistration clientRegistration = this.clientRegistration.build(); + Set scopes = clientRegistration.getScopes(); + JwtBearerGrantRequest grantRequest = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest); + assertThat(accessTokenResponse).isNotNull(); + Instant expiresAtAfter = Instant.now().plusSeconds(3600); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); + assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)) + .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE); + String formParameters = recordedRequest.getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.JWT_BEARER.getValue()), + param(OAuth2ParameterNames.ASSERTION, this.jwtAssertion.getTokenValue()), + param(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(scopes, " ")) + ); + // @formatter:on + assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); + assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); + assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactlyInAnyOrder("read", "write"); + assertThat(accessTokenResponse.getRefreshToken()).isNull(); + } + + @Test + public void getTokenResponseWhenAuthenticationClientSecretBasicThenAuthorizationHeaderIsSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + JwtBearerGrantRequest grantRequest = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); + this.tokenResponseClient.getTokenResponse(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); + } + + @Test + public void getTokenResponseWhenAuthenticationClientSecretPostThenFormParametersAreSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST) + .build(); + JwtBearerGrantRequest grantRequest = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); + this.tokenResponseClient.getTokenResponse(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters).contains("client_id=client-1", "client_secret=secret"); + } + + @Test + public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"not-bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + JwtBearerGrantRequest grantRequest = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(grantRequest)) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo("invalid_token_response")) + .withMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response") + .havingRootCause().withMessage("tokenType cannot be null"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + JwtBearerGrantRequest grantRequest = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest); + assertThat(accessTokenResponse).isNotNull(); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read"); + } + + @Test + public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenAccessTokenHasNoScope() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + JwtBearerGrantRequest grantRequest = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest); + assertThat(accessTokenResponse).isNotNull(); + assertThat(accessTokenResponse.getAccessToken().getScopes()).isEmpty(); + } + + @Test + public void getTokenResponseWhenInvalidResponseThenThrowOAuth2AuthorizationException() { + this.server.enqueue(new MockResponse().setResponseCode(301)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(request)) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo("invalid_token_response")) + .withMessage("[invalid_token_response] Empty OAuth 2.0 Access Token Response"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() { + String accessTokenErrorResponse = "{\"error\": \"server_error\", \"error_description\": \"A server error occurred\"}"; + this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(500)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(request)) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo("invalid_token_response")) + .withMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() { + String accessTokenErrorResponse = "{\"error\": \"invalid_grant\", \"error_description\": \"Invalid grant\"}"; + this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(request)) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_GRANT)) + .withMessage("[invalid_grant] Invalid grant"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenCustomClientAuthenticationMethodThenIllegalArgument() { + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(new ClientAuthenticationMethod("basic")) + .build(); + JwtBearerGrantRequest grantRequest = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); + // @formatter:off + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(grantRequest)) + .withMessageContaining("This class supports `client_secret_basic`, `client_secret_post`, and `none` by default."); + // @formatter:on + } + + @Test + public void getTokenResponseWhenUnsupportedClientAuthenticationMethodThenIllegalArgument() { + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT) + .build(); + JwtBearerGrantRequest grantRequest = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); + // @formatter:off + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(grantRequest)) + .withMessageContaining("This class supports `client_secret_basic`, `client_secret_post`, and `none` by default."); + // @formatter:on + } + + @Test + public void getTokenResponseWhenHeadersConverterAddedThenCalled() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + JwtBearerGrantRequest grantRequest = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); + Converter headersConverter = mock(Converter.class); + HttpHeaders headers = new HttpHeaders(); + headers.put("custom-header-name", Collections.singletonList("custom-header-value")); + given(headersConverter.convert(grantRequest)).willReturn(headers); + this.tokenResponseClient.addHeadersConverter(headersConverter); + this.tokenResponseClient.getTokenResponse(grantRequest); + verify(headersConverter).convert(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); + assertThat(recordedRequest.getHeader("custom-header-name")).isEqualTo("custom-header-value"); + } + + @Test + public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + JwtBearerGrantRequest grantRequest = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); + Converter headersConverter = mock(Converter.class); + HttpHeaders headers = new HttpHeaders(); + headers.put("custom-header-name", Collections.singletonList("custom-header-value")); + given(headersConverter.convert(grantRequest)).willReturn(headers); + this.tokenResponseClient.setHeadersConverter(headersConverter); + this.tokenResponseClient.getTokenResponse(grantRequest); + verify(headersConverter).convert(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); + assertThat(recordedRequest.getHeader("custom-header-name")).isEqualTo("custom-header-value"); + } + + @Test + public void getTokenResponseWhenParametersConverterSetThenCalled() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + JwtBearerGrantRequest grantRequest = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); + Converter> parametersConverter = mock(Converter.class); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.add("custom-parameter-name", "custom-parameter-value"); + given(parametersConverter.convert(grantRequest)).willReturn(parameters); + this.tokenResponseClient.setParametersConverter(parametersConverter); + this.tokenResponseClient.getTokenResponse(grantRequest); + verify(parametersConverter).convert(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters).contains("custom-parameter-name=custom-parameter-value"); + } + + @Test + public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + Set scopes = clientRegistration.getScopes(); + JwtBearerGrantRequest grantRequest = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); + Converter> parametersConverter = mock(Converter.class); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.add("custom-parameter-name", "custom-parameter-value"); + given(parametersConverter.convert(grantRequest)).willReturn(parameters); + this.tokenResponseClient.addParametersConverter(parametersConverter); + this.tokenResponseClient.getTokenResponse(grantRequest); + verify(parametersConverter).convert(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + String formParameters = recordedRequest.getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.JWT_BEARER.getValue()), + param(OAuth2ParameterNames.ASSERTION, this.jwtAssertion.getTokenValue()), + param(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(scopes, " ")), + param("custom-parameter-name", "custom-parameter-value") + ); + // @formatter:on + } + + @Test + public void getTokenResponseWhenRestClientSetThenCalled() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + RestClient customClient = mock(RestClient.class); + given(customClient.post()).willReturn(RestClient.builder().build().post()); + this.tokenResponseClient.setRestClient(customClient); + ClientRegistration clientRegistration = this.clientRegistration.build(); + JwtBearerGrantRequest grantRequest = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); + this.tokenResponseClient.getTokenResponse(grantRequest); + verify(customClient).post(); + } + + private static MockResponse jsonResponse(String json) { + return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json); + } + + private static String param(String parameterName, String parameterValue) { + return "%s=%s".formatted(parameterName, URLEncoder.encode(parameterValue, StandardCharsets.UTF_8)); + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientRefreshTokenTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientRefreshTokenTokenResponseClientTests.java new file mode 100644 index 0000000000..bb14848cd2 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientRefreshTokenTokenResponseClientTests.java @@ -0,0 +1,541 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import java.io.IOException; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.util.Collections; +import java.util.Set; + +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.http.converter.FormHttpMessageConverter; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; +import org.springframework.web.client.RestClient; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link RestClientRefreshTokenTokenResponseClient}. + * + * @author Steve Riesenberg + */ +public class RestClientRefreshTokenTokenResponseClientTests { + + private RestClientRefreshTokenTokenResponseClient tokenResponseClient; + + private MockWebServer server; + + private ClientRegistration.Builder clientRegistration; + + private OAuth2AccessToken accessToken; + + private OAuth2RefreshToken refreshToken; + + @BeforeEach + public void setUp() throws IOException { + this.tokenResponseClient = new RestClientRefreshTokenTokenResponseClient(); + this.server = new MockWebServer(); + this.server.start(); + String tokenUri = this.server.url("/oauth2/token").toString(); + // @formatter:off + this.clientRegistration = TestClientRegistrations.clientCredentials() + .clientId("client-1") + .clientSecret("secret") + .authorizationGrantType(AuthorizationGrantType.REFRESH_TOKEN) + .tokenUri(tokenUri) + .scope("read", "write"); + // @formatter:on + this.accessToken = TestOAuth2AccessTokens.scopes("read", "write"); + this.refreshToken = TestOAuth2RefreshTokens.refreshToken(); + } + + @AfterEach + public void cleanUp() throws IOException { + this.server.shutdown(); + } + + @Test + public void setRestClientWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.setRestClient(null)) + .withMessage("restClient cannot be null"); + // @formatter:on + } + + @Test + public void setHeadersConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.setHeadersConverter(null)) + .withMessage("headersConverter cannot be null"); + // @formatter:on + } + + @Test + public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.addHeadersConverter(null)) + .withMessage("headersConverter cannot be null"); + // @formatter:on + } + + @Test + public void setParametersConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.setParametersConverter(null)) + .withMessage("parametersConverter cannot be null"); + // @formatter:on + } + + @Test + public void addParametersConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.addParametersConverter(null)) + .withMessage("parametersConverter cannot be null"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenGrantRequestIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(null)) + .withMessage("grantRequest cannot be null"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + Instant expiresAtBefore = Instant.now().plusSeconds(3600); + ClientRegistration clientRegistration = this.clientRegistration.build(); + Set scopes = clientRegistration.getScopes(); + OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, + this.accessToken, this.refreshToken, scopes); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest); + assertThat(accessTokenResponse).isNotNull(); + Instant expiresAtAfter = Instant.now().plusSeconds(3600); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); + assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)) + .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE); + String formParameters = recordedRequest.getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.REFRESH_TOKEN.getValue()), + param(OAuth2ParameterNames.REFRESH_TOKEN, this.refreshToken.getTokenValue()), + param(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(scopes, " ")) + ); + // @formatter:on + assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); + assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); + assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactlyInAnyOrder("read", "write"); + assertThat(accessTokenResponse.getRefreshToken().getTokenValue()).isEqualTo(this.refreshToken.getTokenValue()); + } + + @Test + public void getTokenResponseWhenAuthenticationClientSecretBasicThenAuthorizationHeaderIsSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, + this.accessToken, this.refreshToken); + this.tokenResponseClient.getTokenResponse(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); + } + + @Test + public void getTokenResponseWhenAuthenticationClientSecretPostThenFormParametersAreSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST) + .build(); + OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, + this.accessToken, this.refreshToken); + this.tokenResponseClient.getTokenResponse(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters).contains("client_id=client-1", "client_secret=secret"); + } + + @Test + public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"not-bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, + this.accessToken, this.refreshToken); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(grantRequest)) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo("invalid_token_response")) + .withMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response") + .havingRootCause().withMessage("tokenType cannot be null"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, + this.accessToken, this.refreshToken); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest); + assertThat(accessTokenResponse).isNotNull(); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read"); + } + + @Test + public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenAccessTokenHasRequestedScope() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + Set scopes = clientRegistration.getScopes(); + OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, + this.accessToken, this.refreshToken, scopes); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest); + assertThat(accessTokenResponse).isNotNull(); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactlyInAnyOrder("read", "write"); + } + + @Test + public void getTokenResponseWhenRequestDoesNotIncludeScopeThenAccessTokenHasResponseScope() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, + this.accessToken, this.refreshToken); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest); + assertThat(accessTokenResponse).isNotNull(); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); + assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)) + .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE); + String formParameters = recordedRequest.getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.REFRESH_TOKEN.getValue()), + param(OAuth2ParameterNames.REFRESH_TOKEN, this.refreshToken.getTokenValue()) + ); + // @formatter:on + assertThat(formParameters).doesNotContain(OAuth2ParameterNames.SCOPE); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read"); + } + + @Test + public void getTokenResponseWhenInvalidResponseThenThrowOAuth2AuthorizationException() { + this.server.enqueue(new MockResponse().setResponseCode(301)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2RefreshTokenGrantRequest request = new OAuth2RefreshTokenGrantRequest(clientRegistration, + this.accessToken, this.refreshToken); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(request)) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo("invalid_token_response")) + .withMessage("[invalid_token_response] Empty OAuth 2.0 Access Token Response"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() { + String accessTokenErrorResponse = "{\"error\": \"server_error\", \"error_description\": \"A server error occurred\"}"; + this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(500)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2RefreshTokenGrantRequest request = new OAuth2RefreshTokenGrantRequest(clientRegistration, + this.accessToken, this.refreshToken); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(request)) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo("invalid_token_response")) + .withMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() { + String accessTokenErrorResponse = "{\"error\": \"invalid_grant\", \"error_description\": \"Invalid grant\"}"; + this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2RefreshTokenGrantRequest request = new OAuth2RefreshTokenGrantRequest(clientRegistration, + this.accessToken, this.refreshToken); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(request)) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_GRANT)) + .withMessage("[invalid_grant] Invalid grant"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenCustomClientAuthenticationMethodThenIllegalArgument() { + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(new ClientAuthenticationMethod("basic")) + .build(); + OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, + this.accessToken, this.refreshToken); + // @formatter:off + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(grantRequest)) + .withMessageContaining("This class supports `client_secret_basic`, `client_secret_post`, and `none` by default."); + // @formatter:on + } + + @Test + public void getTokenResponseWhenUnsupportedClientAuthenticationMethodThenIllegalArgument() { + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT) + .build(); + OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, + this.accessToken, this.refreshToken); + // @formatter:off + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(grantRequest)) + .withMessageContaining("This class supports `client_secret_basic`, `client_secret_post`, and `none` by default."); + // @formatter:on + } + + @Test + public void getTokenResponseWhenHeadersConverterAddedThenCalled() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, + this.accessToken, this.refreshToken); + Converter headersConverter = mock(Converter.class); + HttpHeaders headers = new HttpHeaders(); + headers.put("custom-header-name", Collections.singletonList("custom-header-value")); + given(headersConverter.convert(grantRequest)).willReturn(headers); + this.tokenResponseClient.addHeadersConverter(headersConverter); + this.tokenResponseClient.getTokenResponse(grantRequest); + verify(headersConverter).convert(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); + assertThat(recordedRequest.getHeader("custom-header-name")).isEqualTo("custom-header-value"); + } + + @Test + public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, + this.accessToken, this.refreshToken); + Converter headersConverter = mock(Converter.class); + HttpHeaders headers = new HttpHeaders(); + headers.put("custom-header-name", Collections.singletonList("custom-header-value")); + given(headersConverter.convert(grantRequest)).willReturn(headers); + this.tokenResponseClient.setHeadersConverter(headersConverter); + this.tokenResponseClient.getTokenResponse(grantRequest); + verify(headersConverter).convert(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); + assertThat(recordedRequest.getHeader("custom-header-name")).isEqualTo("custom-header-value"); + } + + @Test + public void getTokenResponseWhenParametersConverterSetThenCalled() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, + this.accessToken, this.refreshToken); + Converter> parametersConverter = mock( + Converter.class); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.add("custom-parameter-name", "custom-parameter-value"); + given(parametersConverter.convert(grantRequest)).willReturn(parameters); + this.tokenResponseClient.setParametersConverter(parametersConverter); + this.tokenResponseClient.getTokenResponse(grantRequest); + verify(parametersConverter).convert(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters).contains("custom-parameter-name=custom-parameter-value"); + } + + @Test + public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + Set scopes = clientRegistration.getScopes(); + OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, + this.accessToken, this.refreshToken, scopes); + Converter> parametersConverter = mock( + Converter.class); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.add("custom-parameter-name", "custom-parameter-value"); + given(parametersConverter.convert(grantRequest)).willReturn(parameters); + this.tokenResponseClient.addParametersConverter(parametersConverter); + this.tokenResponseClient.getTokenResponse(grantRequest); + verify(parametersConverter).convert(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + String formParameters = recordedRequest.getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.REFRESH_TOKEN.getValue()), + param(OAuth2ParameterNames.REFRESH_TOKEN, this.refreshToken.getTokenValue()), + param(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(scopes, " ")), + param("custom-parameter-name", "custom-parameter-value") + ); + // @formatter:on + } + + @Test + public void getTokenResponseWhenRestClientSetThenCalled() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + RestClient restClient = RestClient.builder().messageConverters((messageConverters) -> { + messageConverters.add(0, new FormHttpMessageConverter()); + messageConverters.add(1, new OAuth2AccessTokenResponseHttpMessageConverter()); + }).build(); + RestClient customClient = spy(restClient); + this.tokenResponseClient.setRestClient(customClient); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, + this.accessToken, this.refreshToken); + this.tokenResponseClient.getTokenResponse(grantRequest); + verify(customClient).post(); + } + + private static MockResponse jsonResponse(String json) { + return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json); + } + + private static String param(String parameterName, String parameterValue) { + return "%s=%s".formatted(parameterName, URLEncoder.encode(parameterValue, StandardCharsets.UTF_8)); + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientTokenExchangeTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientTokenExchangeTokenResponseClientTests.java new file mode 100644 index 0000000000..deb448e248 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientTokenExchangeTokenResponseClientTests.java @@ -0,0 +1,635 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import java.io.IOException; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.util.Collections; +import java.util.Set; + +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.jwt.TestJwts; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; +import org.springframework.web.client.RestClient; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link RestClientTokenExchangeTokenResponseClient}. + * + * @author Steve Riesenberg + */ +public class RestClientTokenExchangeTokenResponseClientTests { + + private static final String ACCESS_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:access_token"; + + private static final String JWT_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:jwt"; + + private RestClientTokenExchangeTokenResponseClient tokenResponseClient; + + private MockWebServer server; + + private ClientRegistration.Builder clientRegistration; + + private OAuth2Token subjectToken; + + private OAuth2Token actorToken; + + @BeforeEach + public void setUp() throws IOException { + this.tokenResponseClient = new RestClientTokenExchangeTokenResponseClient(); + this.server = new MockWebServer(); + this.server.start(); + String tokenUri = this.server.url("/oauth2/token").toString(); + // @formatter:off + this.clientRegistration = TestClientRegistrations.clientCredentials() + .clientId("client-1") + .clientSecret("secret") + .authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE) + .tokenUri(tokenUri) + .scope("read", "write"); + // @formatter:on + this.subjectToken = TestOAuth2AccessTokens.scopes("read", "write"); + this.actorToken = null; + } + + @AfterEach + public void cleanUp() throws IOException { + this.server.shutdown(); + } + + @Test + public void setRestClientWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.setRestClient(null)) + .withMessage("restClient cannot be null"); + // @formatter:on + } + + @Test + public void setHeadersConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.setHeadersConverter(null)) + .withMessage("headersConverter cannot be null"); + // @formatter:on + } + + @Test + public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.addHeadersConverter(null)) + .withMessage("headersConverter cannot be null"); + // @formatter:on + } + + @Test + public void setParametersConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.setParametersConverter(null)) + .withMessage("parametersConverter cannot be null"); + // @formatter:on + } + + @Test + public void addParametersConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.addParametersConverter(null)) + .withMessage("parametersConverter cannot be null"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenGrantRequestIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(null)) + .withMessage("grantRequest cannot be null"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + Instant expiresAtBefore = Instant.now().plusSeconds(3600); + ClientRegistration clientRegistration = this.clientRegistration.build(); + Set scopes = clientRegistration.getScopes(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest); + assertThat(accessTokenResponse).isNotNull(); + Instant expiresAtAfter = Instant.now().plusSeconds(3600); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); + assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)) + .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE); + String formParameters = recordedRequest.getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.TOKEN_EXCHANGE.getValue()), + param(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SUBJECT_TOKEN, this.subjectToken.getTokenValue()), + param(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(scopes, " ")) + ); + // @formatter:on + assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); + assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); + assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactlyInAnyOrder("read", "write"); + assertThat(accessTokenResponse.getRefreshToken()).isNull(); + } + + @Test + public void getTokenResponseWhenSubjectTokenIsJwtThenSubjectTokenTypeIsJwt() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + Instant expiresAtBefore = Instant.now().plusSeconds(3600); + this.subjectToken = TestJwts.jwt().build(); + ClientRegistration clientRegistration = this.clientRegistration.build(); + Set scopes = clientRegistration.getScopes(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest); + assertThat(accessTokenResponse).isNotNull(); + Instant expiresAtAfter = Instant.now().plusSeconds(3600); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); + assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)) + .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE); + String formParameters = recordedRequest.getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.TOKEN_EXCHANGE.getValue()), + param(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SUBJECT_TOKEN, this.subjectToken.getTokenValue()), + param(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE, JWT_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(scopes, " ")) + ); + // @formatter:on + assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); + assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); + assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactlyInAnyOrder("read", "write"); + assertThat(accessTokenResponse.getRefreshToken()).isNull(); + } + + @Test + public void getTokenResponseWhenActorTokenIsNotNullThenActorParametersAreSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + Instant expiresAtBefore = Instant.now().plusSeconds(3600); + this.actorToken = TestOAuth2AccessTokens.noScopes(); + ClientRegistration clientRegistration = this.clientRegistration.build(); + Set scopes = clientRegistration.getScopes(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest); + assertThat(accessTokenResponse).isNotNull(); + Instant expiresAtAfter = Instant.now().plusSeconds(3600); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); + assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)) + .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE); + String formParameters = recordedRequest.getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.TOKEN_EXCHANGE.getValue()), + param(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SUBJECT_TOKEN, this.subjectToken.getTokenValue()), + param(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.ACTOR_TOKEN, this.actorToken.getTokenValue()), + param(OAuth2ParameterNames.ACTOR_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(scopes, " ")) + ); + // @formatter:on + assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); + assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); + assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactlyInAnyOrder("read", "write"); + assertThat(accessTokenResponse.getRefreshToken()).isNull(); + } + + @Test + public void getTokenResponseWhenActorTokenIsJwtThenActorTokenTypeIsJwt() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + Instant expiresAtBefore = Instant.now().plusSeconds(3600); + this.actorToken = TestJwts.jwt().build(); + ClientRegistration clientRegistration = this.clientRegistration.build(); + Set scopes = clientRegistration.getScopes(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest); + assertThat(accessTokenResponse).isNotNull(); + Instant expiresAtAfter = Instant.now().plusSeconds(3600); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); + assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)) + .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE); + String formParameters = recordedRequest.getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.TOKEN_EXCHANGE.getValue()), + param(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SUBJECT_TOKEN, this.subjectToken.getTokenValue()), + param(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.ACTOR_TOKEN, this.actorToken.getTokenValue()), + param(OAuth2ParameterNames.ACTOR_TOKEN_TYPE, JWT_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(scopes, " ")) + ); + // @formatter:on + assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); + assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); + assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactlyInAnyOrder("read", "write"); + assertThat(accessTokenResponse.getRefreshToken()).isNull(); + } + + @Test + public void getTokenResponseWhenAuthenticationClientSecretBasicThenAuthorizationHeaderIsSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + this.tokenResponseClient.getTokenResponse(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); + } + + @Test + public void getTokenResponseWhenAuthenticationClientSecretPostThenFormParametersAreSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST) + .build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + this.tokenResponseClient.getTokenResponse(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters).contains("client_id=client-1", "client_secret=secret"); + } + + @Test + public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"not-bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(grantRequest)) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo("invalid_token_response")) + .withMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response") + .havingRootCause().withMessage("tokenType cannot be null"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest); + assertThat(accessTokenResponse).isNotNull(); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read"); + } + + @Test + public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenAccessTokenHasNoScope() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest); + assertThat(accessTokenResponse).isNotNull(); + assertThat(accessTokenResponse.getAccessToken().getScopes()).isEmpty(); + } + + @Test + public void getTokenResponseWhenInvalidResponseThenThrowOAuth2AuthorizationException() { + this.server.enqueue(new MockResponse().setResponseCode(301)); + TokenExchangeGrantRequest request = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(request)) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo("invalid_token_response")) + .withMessage("[invalid_token_response] Empty OAuth 2.0 Access Token Response"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() { + String accessTokenErrorResponse = "{\"error\": \"server_error\", \"error_description\": \"A server error occurred\"}"; + this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(500)); + TokenExchangeGrantRequest request = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(request)) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo("invalid_token_response")) + .withMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() { + String accessTokenErrorResponse = "{\"error\": \"invalid_grant\", \"error_description\": \"Invalid grant\"}"; + this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400)); + TokenExchangeGrantRequest request = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(request)) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_GRANT)) + .withMessage("[invalid_grant] Invalid grant"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenCustomClientAuthenticationMethodThenIllegalArgument() { + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(new ClientAuthenticationMethod("basic")) + .build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + // @formatter:off + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(grantRequest)) + .withMessageContaining("This class supports `client_secret_basic`, `client_secret_post`, and `none` by default."); + // @formatter:on + } + + @Test + public void getTokenResponseWhenUnsupportedClientAuthenticationMethodThenIllegalArgument() { + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT) + .build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + // @formatter:off + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(grantRequest)) + .withMessageContaining("This class supports `client_secret_basic`, `client_secret_post`, and `none` by default."); + // @formatter:on + } + + @Test + public void getTokenResponseWhenHeadersConverterAddedThenCalled() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + Converter headersConverter = mock(Converter.class); + HttpHeaders headers = new HttpHeaders(); + headers.put("custom-header-name", Collections.singletonList("custom-header-value")); + given(headersConverter.convert(grantRequest)).willReturn(headers); + this.tokenResponseClient.addHeadersConverter(headersConverter); + this.tokenResponseClient.getTokenResponse(grantRequest); + verify(headersConverter).convert(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); + assertThat(recordedRequest.getHeader("custom-header-name")).isEqualTo("custom-header-value"); + } + + @Test + public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + Converter headersConverter = mock(Converter.class); + HttpHeaders headers = new HttpHeaders(); + headers.put("custom-header-name", Collections.singletonList("custom-header-value")); + given(headersConverter.convert(grantRequest)).willReturn(headers); + this.tokenResponseClient.setHeadersConverter(headersConverter); + this.tokenResponseClient.getTokenResponse(grantRequest); + verify(headersConverter).convert(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); + assertThat(recordedRequest.getHeader("custom-header-name")).isEqualTo("custom-header-value"); + } + + @Test + public void getTokenResponseWhenParametersConverterSetThenCalled() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + Converter> parametersConverter = mock(Converter.class); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.add("custom-parameter-name", "custom-parameter-value"); + given(parametersConverter.convert(grantRequest)).willReturn(parameters); + this.tokenResponseClient.setParametersConverter(parametersConverter); + this.tokenResponseClient.getTokenResponse(grantRequest); + verify(parametersConverter).convert(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters).contains("custom-parameter-name=custom-parameter-value"); + } + + @Test + public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + Set scopes = clientRegistration.getScopes(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + Converter> parametersConverter = mock(Converter.class); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.add("custom-parameter-name", "custom-parameter-value"); + given(parametersConverter.convert(grantRequest)).willReturn(parameters); + this.tokenResponseClient.addParametersConverter(parametersConverter); + this.tokenResponseClient.getTokenResponse(grantRequest); + verify(parametersConverter).convert(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + String formParameters = recordedRequest.getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.TOKEN_EXCHANGE.getValue()), + param(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SUBJECT_TOKEN, this.subjectToken.getTokenValue()), + param(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(scopes, " ")), + param("custom-parameter-name", "custom-parameter-value") + ); + // @formatter:on + } + + @Test + public void getTokenResponseWhenRestClientSetThenCalled() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + RestClient customClient = mock(RestClient.class); + given(customClient.post()).willReturn(RestClient.builder().build().post()); + this.tokenResponseClient.setRestClient(customClient); + ClientRegistration clientRegistration = this.clientRegistration.build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + this.tokenResponseClient.getTokenResponse(grantRequest); + verify(customClient).post(); + } + + private static MockResponse jsonResponse(String json) { + return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json); + } + + private static String param(String parameterName, String parameterValue) { + return "%s=%s".formatted(parameterName, URLEncoder.encode(parameterValue, StandardCharsets.UTF_8)); + } + +}