diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java index 982f660f6e..baa697a009 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java @@ -20,6 +20,7 @@ import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Import; import org.springframework.context.annotation.ImportSelector; import org.springframework.core.type.AnnotationMetadata; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.method.annotation.OAuth2AuthorizedClientArgumentResolver; import org.springframework.util.ClassUtils; @@ -57,17 +58,26 @@ final class OAuth2ClientConfiguration { @Configuration static class OAuth2ClientWebMvcSecurityConfiguration implements WebMvcConfigurer { + private ClientRegistrationRepository clientRegistrationRepository; private OAuth2AuthorizedClientRepository authorizedClientRepository; @Override public void addArgumentResolvers(List argumentResolvers) { - if (this.authorizedClientRepository != null) { + if (this.clientRegistrationRepository != null && this.authorizedClientRepository != null) { OAuth2AuthorizedClientArgumentResolver authorizedClientArgumentResolver = - new OAuth2AuthorizedClientArgumentResolver(this.authorizedClientRepository); + new OAuth2AuthorizedClientArgumentResolver( + this.clientRegistrationRepository, this.authorizedClientRepository); argumentResolvers.add(authorizedClientArgumentResolver); } } + @Autowired(required = false) + public void setClientRegistrationRepository(List clientRegistrationRepositories) { + if (clientRegistrationRepositories.size() == 1) { + this.clientRegistrationRepository = clientRegistrationRepositories.get(0); + } + } + @Autowired(required = false) public void setAuthorizedClientRepository(List authorizedClientRepositories) { if (authorizedClientRepositories.size() == 1) { diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java index 524f4c57d8..eb15b4ef51 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java @@ -98,6 +98,11 @@ public class OAuth2ClientConfigurationTests { } } + @Bean + public ClientRegistrationRepository clientRegistrationRepository() { + return mock(ClientRegistrationRepository.class); + } + @Bean public OAuth2AuthorizedClientRepository authorizedClientRepository() { return AUTHORIZED_CLIENT_REPOSITORY; diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClient.java new file mode 100644 index 0000000000..f99409a276 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClient.java @@ -0,0 +1,270 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.endpoint; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestOperations; +import org.springframework.web.client.RestTemplate; +import org.springframework.web.util.UriComponentsBuilder; + +import java.io.IOException; +import java.net.URI; +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * The default implementation of an {@link OAuth2AccessTokenResponseClient} + * for the {@link AuthorizationGrantType#CLIENT_CREDENTIALS client_credentials} grant. + * This implementation uses a {@link RestOperations} when requesting + * an access token credential at the Authorization Server's Token Endpoint. + * + * @author Joe Grandja + * @since 5.1 + * @see OAuth2AccessTokenResponseClient + * @see OAuth2ClientCredentialsGrantRequest + * @see OAuth2AccessTokenResponse + * @see Section 4.4.2 Access Token Request (Client Credentials Grant) + * @see Section 4.4.3 Access Token Response (Client Credentials Grant) + */ +public class DefaultClientCredentialsTokenResponseClient implements OAuth2AccessTokenResponseClient { + private static final String INVALID_TOKEN_REQUEST_ERROR_CODE = "invalid_token_request"; + + private static final String INVALID_TOKEN_RESPONSE_ERROR_CODE = "invalid_token_response"; + + private static final String[] TOKEN_RESPONSE_PARAMETER_NAMES = { + OAuth2ParameterNames.ACCESS_TOKEN, + OAuth2ParameterNames.TOKEN_TYPE, + OAuth2ParameterNames.EXPIRES_IN, + OAuth2ParameterNames.SCOPE, + OAuth2ParameterNames.REFRESH_TOKEN + }; + + private RestOperations restOperations; + + public DefaultClientCredentialsTokenResponseClient() { + RestTemplate restTemplate = new RestTemplate(); + // Disable the ResponseErrorHandler as errors are handled directly within this class + restTemplate.setErrorHandler(new NoOpResponseErrorHandler()); + this.restOperations = restTemplate; + } + + @Override + public OAuth2AccessTokenResponse getTokenResponse(OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest) + throws OAuth2AuthenticationException { + + Assert.notNull(clientCredentialsGrantRequest, "clientCredentialsGrantRequest cannot be null"); + + // Build request + RequestEntity> request = this.buildRequest(clientCredentialsGrantRequest); + + // Exchange + ResponseEntity> response; + try { + response = this.restOperations.exchange( + request, new ParameterizedTypeReference>() {}); + } catch (Exception ex) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_REQUEST_ERROR_CODE, + "An error occurred while sending the Access Token Request: " + ex.getMessage(), null); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex); + } + + Map responseParameters = response.getBody(); + + // Check for Error Response + if (response.getStatusCodeValue() != 200) { + OAuth2Error oauth2Error = this.parseErrorResponse(responseParameters); + if (oauth2Error == null) { + oauth2Error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR); + } + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } + + // Success Response + OAuth2AccessTokenResponse tokenResponse; + try { + tokenResponse = this.parseTokenResponse(responseParameters); + } catch (Exception ex) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, + "An error occurred parsing the Access Token response (200 OK): " + ex.getMessage(), null); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex); + } + + if (tokenResponse == null) { + // This should never happen as long as the provider + // implements a Successful Response as defined in Section 5.1 + // https://tools.ietf.org/html/rfc6749#section-5.1 + OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, + "An error occurred parsing the Access Token response (200 OK). " + + "Missing required parameters: access_token and/or token_type", null); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } + + if (CollectionUtils.isEmpty(tokenResponse.getAccessToken().getScopes())) { + // As per spec, in Section 5.1 Successful Access Token Response + // https://tools.ietf.org/html/rfc6749#section-5.1 + // If AccessTokenResponse.scope is empty, then default to the scope + // originally requested by the client in the Token Request + tokenResponse = OAuth2AccessTokenResponse.withResponse(tokenResponse) + .scopes(clientCredentialsGrantRequest.getClientRegistration().getScopes()) + .build(); + } + + return tokenResponse; + } + + private RequestEntity> buildRequest(OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest) { + HttpHeaders headers = this.buildHeaders(clientCredentialsGrantRequest); + MultiValueMap formParameters = this.buildFormParameters(clientCredentialsGrantRequest); + URI uri = UriComponentsBuilder.fromUriString(clientCredentialsGrantRequest.getClientRegistration().getProviderDetails().getTokenUri()) + .build() + .toUri(); + + return new RequestEntity<>(formParameters, headers, HttpMethod.POST, uri); + } + + private HttpHeaders buildHeaders(OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest) { + ClientRegistration clientRegistration = clientCredentialsGrantRequest.getClientRegistration(); + + HttpHeaders headers = new HttpHeaders(); + headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON)); + headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED); + if (ClientAuthenticationMethod.BASIC.equals(clientRegistration.getClientAuthenticationMethod())) { + headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret()); + } + + return headers; + } + + private MultiValueMap buildFormParameters(OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest) { + ClientRegistration clientRegistration = clientCredentialsGrantRequest.getClientRegistration(); + + MultiValueMap formParameters = new LinkedMultiValueMap<>(); + formParameters.add(OAuth2ParameterNames.GRANT_TYPE, clientCredentialsGrantRequest.getGrantType().getValue()); + if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) { + formParameters.add(OAuth2ParameterNames.SCOPE, + StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " ")); + } + if (ClientAuthenticationMethod.POST.equals(clientRegistration.getClientAuthenticationMethod())) { + formParameters.add(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId()); + formParameters.add(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret()); + } + + return formParameters; + } + + private OAuth2Error parseErrorResponse(Map responseParameters) { + if (CollectionUtils.isEmpty(responseParameters) || + !responseParameters.containsKey(OAuth2ParameterNames.ERROR)) { + return null; + } + + String errorCode = responseParameters.get(OAuth2ParameterNames.ERROR); + String errorDescription = responseParameters.get(OAuth2ParameterNames.ERROR_DESCRIPTION); + String errorUri = responseParameters.get(OAuth2ParameterNames.ERROR_URI); + + return new OAuth2Error(errorCode, errorDescription, errorUri); + } + + private OAuth2AccessTokenResponse parseTokenResponse(Map responseParameters) { + if (CollectionUtils.isEmpty(responseParameters) || + !responseParameters.containsKey(OAuth2ParameterNames.ACCESS_TOKEN) || + !responseParameters.containsKey(OAuth2ParameterNames.TOKEN_TYPE)) { + return null; + } + + String accessToken = responseParameters.get(OAuth2ParameterNames.ACCESS_TOKEN); + + OAuth2AccessToken.TokenType accessTokenType = null; + if (OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase( + responseParameters.get(OAuth2ParameterNames.TOKEN_TYPE))) { + accessTokenType = OAuth2AccessToken.TokenType.BEARER; + } + + long expiresIn = 0; + if (responseParameters.containsKey(OAuth2ParameterNames.EXPIRES_IN)) { + try { + expiresIn = Long.valueOf(responseParameters.get(OAuth2ParameterNames.EXPIRES_IN)); + } catch (NumberFormatException ex) { } + } + + Set scopes = Collections.emptySet(); + if (responseParameters.containsKey(OAuth2ParameterNames.SCOPE)) { + String scope = responseParameters.get(OAuth2ParameterNames.SCOPE); + scopes = Arrays.stream(StringUtils.delimitedListToStringArray(scope, " ")).collect(Collectors.toSet()); + } + + Map additionalParameters = new LinkedHashMap<>(); + Set tokenResponseParameterNames = Stream.of(TOKEN_RESPONSE_PARAMETER_NAMES).collect(Collectors.toSet()); + responseParameters.entrySet().stream() + .filter(e -> !tokenResponseParameterNames.contains(e.getKey())) + .forEach(e -> additionalParameters.put(e.getKey(), e.getValue())); + + return OAuth2AccessTokenResponse.withToken(accessToken) + .tokenType(accessTokenType) + .expiresIn(expiresIn) + .scopes(scopes) + .additionalParameters(additionalParameters) + .build(); + } + + /** + * Sets the {@link RestOperations} used when requesting the access token response. + * + * @param restOperations the {@link RestOperations} used when requesting the access token response + */ + public final void setRestOperations(RestOperations restOperations) { + Assert.notNull(restOperations, "restOperations cannot be null"); + this.restOperations = restOperations; + } + + private static class NoOpResponseErrorHandler implements ResponseErrorHandler { + + @Override + public boolean hasError(ClientHttpResponse response) throws IOException { + return false; + } + + @Override + public void handleError(ClientHttpResponse response) throws IOException { + } + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequest.java new file mode 100644 index 0000000000..9f62c8671c --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequest.java @@ -0,0 +1,56 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.endpoint; + +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.util.Assert; + +/** + * An OAuth 2.0 Client Credentials Grant request that holds + * the client's credentials in {@link #getClientRegistration()}. + * + * @author Joe Grandja + * @since 5.1 + * @see AbstractOAuth2AuthorizationGrantRequest + * @see ClientRegistration + * @see Section 1.3.4 Client Credentials Grant + */ +public class OAuth2ClientCredentialsGrantRequest extends AbstractOAuth2AuthorizationGrantRequest { + private final ClientRegistration clientRegistration; + + /** + * Constructs an {@code OAuth2ClientCredentialsGrantRequest} using the provided parameters. + * + * @param clientRegistration the client registration + */ + public OAuth2ClientCredentialsGrantRequest(ClientRegistration clientRegistration) { + super(AuthorizationGrantType.CLIENT_CREDENTIALS); + Assert.notNull(clientRegistration, "clientRegistration cannot be null"); + Assert.isTrue(AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType()), + "clientRegistration.authorizationGrantType must be AuthorizationGrantType.CLIENT_CREDENTIALS"); + this.clientRegistration = clientRegistration; + } + + /** + * Returns the {@link ClientRegistration client registration}. + * + * @return the {@link ClientRegistration} + */ + public ClientRegistration getClientRegistration() { + return this.clientRegistration; + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java index 9783ff44c1..f6342fc06a 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java @@ -448,7 +448,9 @@ public final class ClientRegistration { */ public ClientRegistration build() { Assert.notNull(this.authorizationGrantType, "authorizationGrantType cannot be null"); - if (AuthorizationGrantType.IMPLICIT.equals(this.authorizationGrantType)) { + if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(this.authorizationGrantType)) { + this.validateClientCredentialsGrantType(); + } else if (AuthorizationGrantType.IMPLICIT.equals(this.authorizationGrantType)) { this.validateImplicitGrantType(); } else { this.validateAuthorizationCodeGrantType(); @@ -507,5 +509,15 @@ public final class ClientRegistration { Assert.hasText(this.authorizationUri, "authorizationUri cannot be empty"); Assert.hasText(this.clientName, "clientName cannot be empty"); } + + private void validateClientCredentialsGrantType() { + Assert.isTrue(AuthorizationGrantType.CLIENT_CREDENTIALS.equals(this.authorizationGrantType), + () -> "authorizationGrantType must be " + AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()); + Assert.hasText(this.registrationId, "registrationId cannot be empty"); + Assert.hasText(this.clientId, "clientId cannot be empty"); + Assert.hasText(this.clientSecret, "clientSecret cannot be empty"); + Assert.notNull(this.clientAuthenticationMethod, "clientAuthenticationMethod cannot be null"); + Assert.hasText(this.tokenUri, "tokenUri cannot be empty"); + } } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java index 67daf2955b..91d7db916b 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java @@ -25,7 +25,14 @@ import org.springframework.security.oauth2.client.ClientAuthorizationRequiredExc import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.endpoint.DefaultClientCredentialsTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.web.bind.support.WebDataBinderFactory; @@ -34,6 +41,7 @@ import org.springframework.web.method.support.HandlerMethodArgumentResolver; import org.springframework.web.method.support.ModelAndViewContainer; import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; /** * An implementation of a {@link HandlerMethodArgumentResolver} that is capable @@ -56,15 +64,22 @@ import javax.servlet.http.HttpServletRequest; * @see RegisteredOAuth2AuthorizedClient */ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMethodArgumentResolver { + private final ClientRegistrationRepository clientRegistrationRepository; private final OAuth2AuthorizedClientRepository authorizedClientRepository; + private OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient = + new DefaultClientCredentialsTokenResponseClient(); /** * Constructs an {@code OAuth2AuthorizedClientArgumentResolver} using the provided parameters. * - * @param authorizedClientRepository the authorized client repository + * @param clientRegistrationRepository the repository of client registrations + * @param authorizedClientRepository the repository of authorized clients */ - public OAuth2AuthorizedClientArgumentResolver(OAuth2AuthorizedClientRepository authorizedClientRepository) { + public OAuth2AuthorizedClientArgumentResolver(ClientRegistrationRepository clientRegistrationRepository, + OAuth2AuthorizedClientRepository authorizedClientRepository) { + Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); + this.clientRegistrationRepository = clientRegistrationRepository; this.authorizedClientRepository = authorizedClientRepository; } @@ -83,8 +98,43 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth NativeWebRequest webRequest, @Nullable WebDataBinderFactory binderFactory) throws Exception { + String clientRegistrationId = this.resolveClientRegistrationId(parameter); + if (StringUtils.isEmpty(clientRegistrationId)) { + throw new IllegalArgumentException("Unable to resolve the Client Registration Identifier. " + + "It must be provided via @RegisteredOAuth2AuthorizedClient(\"client1\") or " + + "@RegisteredOAuth2AuthorizedClient(registrationId = \"client1\")."); + } + + Authentication principal = SecurityContextHolder.getContext().getAuthentication(); + HttpServletRequest servletRequest = webRequest.getNativeRequest(HttpServletRequest.class); + + OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient( + clientRegistrationId, principal, servletRequest); + if (authorizedClient != null) { + return authorizedClient; + } + + ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId); + if (clientRegistration == null) { + return null; + } + + if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) { + throw new ClientAuthorizationRequiredException(clientRegistrationId); + } + + if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) { + HttpServletResponse servletResponse = webRequest.getNativeResponse(HttpServletResponse.class); + authorizedClient = this.authorizeClientCredentialsClient(clientRegistration, servletRequest, servletResponse); + } + + return authorizedClient; + } + + private String resolveClientRegistrationId(MethodParameter parameter) { RegisteredOAuth2AuthorizedClient authorizedClientAnnotation = AnnotatedElementUtils.findMergedAnnotation( parameter.getParameter(), RegisteredOAuth2AuthorizedClient.class); + Authentication principal = SecurityContextHolder.getContext().getAuthentication(); String clientRegistrationId = null; @@ -95,17 +145,41 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth } else if (principal != null && OAuth2AuthenticationToken.class.isAssignableFrom(principal.getClass())) { clientRegistrationId = ((OAuth2AuthenticationToken) principal).getAuthorizedClientRegistrationId(); } - if (StringUtils.isEmpty(clientRegistrationId)) { - throw new IllegalArgumentException("Unable to resolve the Client Registration Identifier. " + - "It must be provided via @RegisteredOAuth2AuthorizedClient(\"client1\") or @RegisteredOAuth2AuthorizedClient(registrationId = \"client1\")."); - } - OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient( - clientRegistrationId, principal, webRequest.getNativeRequest(HttpServletRequest.class)); - if (authorizedClient == null) { - throw new ClientAuthorizationRequiredException(clientRegistrationId); - } + return clientRegistrationId; + } + + private OAuth2AuthorizedClient authorizeClientCredentialsClient(ClientRegistration clientRegistration, + HttpServletRequest request, HttpServletResponse response) { + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = + new OAuth2ClientCredentialsGrantRequest(clientRegistration); + OAuth2AccessTokenResponse tokenResponse = + this.clientCredentialsTokenResponseClient.getTokenResponse(clientCredentialsGrantRequest); + + Authentication principal = SecurityContextHolder.getContext().getAuthentication(); + + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + clientRegistration, + (principal != null ? principal.getName() : "anonymousUser"), + tokenResponse.getAccessToken()); + + this.authorizedClientRepository.saveAuthorizedClient( + authorizedClient, + principal, + request, + response); return authorizedClient; } + + /** + * Sets the client used when requesting an access token credential at the Token Endpoint for the {@code client_credentials} grant. + * + * @param clientCredentialsTokenResponseClient the client used when requesting an access token credential at the Token Endpoint for the {@code client_credentials} grant + */ + public final void setClientCredentialsTokenResponseClient( + OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) { + Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null"); + this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient; + } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClientTests.java new file mode 100644 index 0000000000..117a17cb34 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClientTests.java @@ -0,0 +1,326 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.endpoint; + +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; + +import java.time.Instant; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link DefaultClientCredentialsTokenResponseClient}. + * + * @author Joe Grandja + */ +public class DefaultClientCredentialsTokenResponseClientTests { + private DefaultClientCredentialsTokenResponseClient tokenResponseClient = new DefaultClientCredentialsTokenResponseClient(); + private ClientRegistration clientRegistration; + private MockWebServer server; + + @Before + public void setup() throws Exception { + this.server = new MockWebServer(); + this.server.start(); + + String tokenUri = this.server.url("/oauth2/token").toString(); + + this.clientRegistration = ClientRegistration.withRegistrationId("registration-1") + .clientId("client-1") + .clientSecret("secret") + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) + .scope("read", "write") + .tokenUri(tokenUri) + .build(); + } + + @After + public void cleanup() throws Exception { + this.server.shutdown(); + } + + @Test + public void setRestOperationsWhenRestOperationsIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.tokenResponseClient.setRestOperations(null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void getTokenResponseWhenRequestIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\",\n" + + " \"custom_parameter_1\": \"custom-value-1\",\n" + + " \"custom_parameter_2\": \"custom-value-2\"\n" + + "}\n"; + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + Instant expiresAtBefore = Instant.now().plusSeconds(3600); + + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = + new OAuth2ClientCredentialsGrantRequest(this.clientRegistration); + + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest); + + Instant expiresAtAfter = Instant.now().plusSeconds(3600); + + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); + assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON.toString()); + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)).startsWith(MediaType.APPLICATION_FORM_URLENCODED.toString()); + + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters).contains("grant_type=client_credentials"); + assertThat(formParameters).contains("scope=read+write"); + + assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); + assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); + assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read", "write"); + assertThat(accessTokenResponse.getRefreshToken()).isNull(); + assertThat(accessTokenResponse.getAdditionalParameters().size()).isEqualTo(2); + assertThat(accessTokenResponse.getAdditionalParameters()).containsEntry("custom_parameter_1", "custom-value-1"); + assertThat(accessTokenResponse.getAdditionalParameters()).containsEntry("custom_parameter_2", "custom-value-2"); + } + + @Test + public void getTokenResponseWhenClientAuthenticationBasicThenAuthorizationHeaderIsSent() throws Exception { + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = + new OAuth2ClientCredentialsGrantRequest(this.clientRegistration); + + this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest); + + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); + } + + @Test + public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSent() throws Exception { + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + ClientRegistration clientRegistration = this.from(this.clientRegistration) + .clientAuthenticationMethod(ClientAuthenticationMethod.POST) + .build(); + + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = + new OAuth2ClientCredentialsGrantRequest(clientRegistration); + + this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest); + + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); + + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters).contains("client_id=client-1"); + assertThat(formParameters).contains("client_secret=secret"); + } + + @Test + public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthenticationException() { + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"not-bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = + new OAuth2ClientCredentialsGrantRequest(this.clientRegistration); + + assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest)) + .isInstanceOf(OAuth2AuthenticationException.class) + .hasMessageContaining("[invalid_token_response] An error occurred parsing the Access Token response (200 OK): tokenType cannot be null"); + } + + @Test + public void getTokenResponseWhenSuccessResponseAndMissingTokenTypeParameterThenThrowOAuth2AuthenticationException() { + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\"\n" + + "}\n"; + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = + new OAuth2ClientCredentialsGrantRequest(this.clientRegistration); + + assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest)) + .isInstanceOf(OAuth2AuthenticationException.class) + .hasMessageContaining("[invalid_token_response] An error occurred parsing the Access Token response (200 OK). Missing required parameters: access_token and/or token_type"); + } + + @Test + public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() { + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read\"\n" + + "}\n"; + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = + new OAuth2ClientCredentialsGrantRequest(this.clientRegistration); + + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest); + + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read"); + } + + @Test + public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenAccessTokenHasDefaultScope() { + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = + new OAuth2ClientCredentialsGrantRequest(this.clientRegistration); + + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest); + + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read", "write"); + } + + @Test + public void getTokenResponseWhenTokenUriMalformedThenThrowOAuth2AuthenticationException() { + String malformedTokenUri = "http:\\provider.com\\oauth2\\token"; + ClientRegistration clientRegistration = this.from(this.clientRegistration) + .tokenUri(malformedTokenUri) + .build(); + + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = + new OAuth2ClientCredentialsGrantRequest(clientRegistration); + + assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest)) + .isInstanceOf(OAuth2AuthenticationException.class) + .hasMessageContaining("[invalid_token_request] An error occurred while sending the Access Token Request:"); + } + + @Test + public void getTokenResponseWhenTokenUriInvalidThenThrowOAuth2AuthenticationException() { + String invalidTokenUri = "http://invalid-provider.com/oauth2/token"; + ClientRegistration clientRegistration = this.from(this.clientRegistration) + .tokenUri(invalidTokenUri) + .build(); + + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = + new OAuth2ClientCredentialsGrantRequest(clientRegistration); + + assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest)) + .isInstanceOf(OAuth2AuthenticationException.class) + .hasMessageContaining("[invalid_token_request] An error occurred while sending the Access Token Request:"); + } + + @Test + public void getTokenResponseWhenMalformedResponseThenThrowOAuth2AuthenticationException() { + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\",\n" + + " \"custom_parameter_1\": \"custom-value-1\",\n" + + " \"custom_parameter_2\": \"custom-value-2\"\n"; +// "}\n"; // Make the JSON invalid/malformed + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = + new OAuth2ClientCredentialsGrantRequest(this.clientRegistration); + + assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest)) + .isInstanceOf(OAuth2AuthenticationException.class) + .hasMessageContaining("[invalid_token_request] An error occurred while sending the Access Token Request:"); + } + + @Test + public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthenticationException() { + String accessTokenErrorResponse = "{\n" + + " \"error\": \"unauthorized_client\"\n" + + "}\n"; + this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400)); + + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = + new OAuth2ClientCredentialsGrantRequest(this.clientRegistration); + + assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest)) + .isInstanceOf(OAuth2AuthenticationException.class) + .hasMessageContaining("[unauthorized_client]"); + } + + @Test + public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthenticationException() { + this.server.enqueue(new MockResponse().setResponseCode(500)); + + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = + new OAuth2ClientCredentialsGrantRequest(this.clientRegistration); + + assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest)) + .isInstanceOf(OAuth2AuthenticationException.class) + .hasMessageContaining("[server_error]"); + } + + private MockResponse jsonResponse(String json) { + return new MockResponse() + .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(json); + } + + private ClientRegistration.Builder from(ClientRegistration registration) { + return ClientRegistration.withRegistrationId(registration.getRegistrationId()) + .clientId(registration.getClientId()) + .clientSecret(registration.getClientSecret()) + .clientAuthenticationMethod(registration.getClientAuthenticationMethod()) + .authorizationGrantType(registration.getAuthorizationGrantType()) + .scope(registration.getScopes()) + .tokenUri(registration.getProviderDetails().getTokenUri()); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestTests.java new file mode 100644 index 0000000000..47e9013248 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestTests.java @@ -0,0 +1,76 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.endpoint; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Java6Assertions.assertThatThrownBy; + +/** + * Tests for {@link OAuth2ClientCredentialsGrantRequest}. + * + * @author Joe Grandja + */ +public class OAuth2ClientCredentialsGrantRequestTests { + private ClientRegistration clientRegistration; + + @Before + public void setup() { + this.clientRegistration = ClientRegistration.withRegistrationId("registration-1") + .clientId("client-1") + .clientSecret("secret") + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) + .scope("read", "write") + .tokenUri("https://provider.com/oauth2/token") + .build(); + } + + @Test + public void constructorWhenClientRegistrationIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2ClientCredentialsGrantRequest(null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void constructorWhenClientRegistrationInvalidGrantTypeThenThrowIllegalArgumentException() { + ClientRegistration clientRegistration = ClientRegistration.withRegistrationId("registration-1") + .clientId("client-1") + .authorizationGrantType(AuthorizationGrantType.IMPLICIT) + .redirectUriTemplate("https://localhost:8080/redirect-uri") + .authorizationUri("https://provider.com/oauth2/auth") + .clientName("Client 1") + .build(); + + assertThatThrownBy(() -> new OAuth2ClientCredentialsGrantRequest(clientRegistration)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clientRegistration.authorizationGrantType must be AuthorizationGrantType.CLIENT_CREDENTIALS"); + } + + @Test + public void constructorWhenValidParametersProvidedThenCreated() { + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = + new OAuth2ClientCredentialsGrantRequest(this.clientRegistration); + + assertThat(clientCredentialsGrantRequest.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(clientCredentialsGrantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.CLIENT_CREDENTIALS); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationTests.java index 93a30b0505..b1218d3295 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationTests.java @@ -25,6 +25,7 @@ import java.util.LinkedHashSet; import java.util.Set; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; /** * Tests for {@link ClientRegistration}. @@ -411,4 +412,91 @@ public class ClientRegistrationTests { assertThat(registration.getRegistrationId()).isEqualTo(overriddenId); } + + @Test + public void buildWhenClientCredentialsGrantAllAttributesProvidedThenAllAttributesAreSet() { + ClientRegistration registration = ClientRegistration.withRegistrationId(REGISTRATION_ID) + .clientId(CLIENT_ID) + .clientSecret(CLIENT_SECRET) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) + .scope(SCOPES.toArray(new String[0])) + .tokenUri(TOKEN_URI) + .clientName(CLIENT_NAME) + .build(); + + assertThat(registration.getRegistrationId()).isEqualTo(REGISTRATION_ID); + assertThat(registration.getClientId()).isEqualTo(CLIENT_ID); + assertThat(registration.getClientSecret()).isEqualTo(CLIENT_SECRET); + assertThat(registration.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.BASIC); + assertThat(registration.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.CLIENT_CREDENTIALS); + assertThat(registration.getScopes()).isEqualTo(SCOPES); + assertThat(registration.getProviderDetails().getTokenUri()).isEqualTo(TOKEN_URI); + assertThat(registration.getClientName()).isEqualTo(CLIENT_NAME); + } + + @Test + public void buildWhenClientCredentialsGrantRegistrationIdIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> + ClientRegistration.withRegistrationId(null) + .clientId(CLIENT_ID) + .clientSecret(CLIENT_SECRET) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) + .tokenUri(TOKEN_URI) + .build() + ).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void buildWhenClientCredentialsGrantClientIdIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> + ClientRegistration.withRegistrationId(REGISTRATION_ID) + .clientId(null) + .clientSecret(CLIENT_SECRET) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) + .tokenUri(TOKEN_URI) + .build() + ).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void buildWhenClientCredentialsGrantClientSecretIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> + ClientRegistration.withRegistrationId(REGISTRATION_ID) + .clientId(CLIENT_ID) + .clientSecret(null) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) + .tokenUri(TOKEN_URI) + .build() + ).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void buildWhenClientCredentialsGrantClientAuthenticationMethodIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> + ClientRegistration.withRegistrationId(REGISTRATION_ID) + .clientId(CLIENT_ID) + .clientSecret(CLIENT_SECRET) + .clientAuthenticationMethod(null) + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) + .tokenUri(TOKEN_URI) + .build() + ).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void buildWhenClientCredentialsGrantTokenUriIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> + ClientRegistration.withRegistrationId(REGISTRATION_ID) + .clientId(CLIENT_ID) + .clientSecret(CLIENT_SECRET) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) + .tokenUri(null) + .build() + ).isInstanceOf(IllegalArgumentException.class); + } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java index d67d527da0..b49a1b2c01 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java @@ -20,6 +20,7 @@ import org.junit.Before; import org.junit.Test; import org.springframework.core.MethodParameter; import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; @@ -27,7 +28,16 @@ import org.springframework.security.oauth2.client.ClientAuthorizationRequiredExc import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.util.ReflectionUtils; import org.springframework.web.context.request.ServletWebRequest; @@ -38,8 +48,8 @@ import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; /** * Tests for {@link OAuth2AuthorizedClientArgumentResolver}. @@ -47,22 +57,58 @@ import static org.mockito.Mockito.when; * @author Joe Grandja */ public class OAuth2AuthorizedClientArgumentResolverTests { + private TestingAuthenticationToken authentication; + private String principalName = "principal-1"; + private ClientRegistration registration1; + private ClientRegistration registration2; + private ClientRegistrationRepository clientRegistrationRepository; + private OAuth2AuthorizedClient authorizedClient1; + private OAuth2AuthorizedClient authorizedClient2; private OAuth2AuthorizedClientRepository authorizedClientRepository; private OAuth2AuthorizedClientArgumentResolver argumentResolver; - private OAuth2AuthorizedClient authorizedClient; private MockHttpServletRequest request; @Before public void setup() { - this.authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class); - this.argumentResolver = new OAuth2AuthorizedClientArgumentResolver(this.authorizedClientRepository); - this.authorizedClient = mock(OAuth2AuthorizedClient.class); - this.request = new MockHttpServletRequest(); - when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any(HttpServletRequest.class))) - .thenReturn(this.authorizedClient); + this.authentication = new TestingAuthenticationToken(this.principalName, "password"); SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); - securityContext.setAuthentication(mock(Authentication.class)); + securityContext.setAuthentication(this.authentication); SecurityContextHolder.setContext(securityContext); + + this.registration1 = ClientRegistration.withRegistrationId("client1") + .clientId("client-1") + .clientSecret("secret") + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .redirectUriTemplate("{baseUrl}/login/oauth2/code/{registrationId}") + .scope("user") + .authorizationUri("https://provider.com/oauth2/authorize") + .tokenUri("https://provider.com/oauth2/token") + .userInfoUri("https://provider.com/oauth2/user") + .userNameAttributeName("id") + .clientName("client-1") + .build(); + this.registration2 = ClientRegistration.withRegistrationId("client2") + .clientId("client-2") + .clientSecret("secret") + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) + .scope("read", "write") + .tokenUri("https://provider.com/oauth2/token") + .build(); + this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(this.registration1, this.registration2); + this.authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class); + this.argumentResolver = new OAuth2AuthorizedClientArgumentResolver( + this.clientRegistrationRepository, this.authorizedClientRepository); + this.authorizedClient1 = new OAuth2AuthorizedClient(this.registration1, this.principalName, mock(OAuth2AccessToken.class)); + when(this.authorizedClientRepository.loadAuthorizedClient( + eq(this.registration1.getRegistrationId()), any(Authentication.class), any(HttpServletRequest.class))) + .thenReturn(this.authorizedClient1); + this.authorizedClient2 = new OAuth2AuthorizedClient(this.registration2, this.principalName, mock(OAuth2AccessToken.class)); + when(this.authorizedClientRepository.loadAuthorizedClient( + eq(this.registration2.getRegistrationId()), any(Authentication.class), any(HttpServletRequest.class))) + .thenReturn(this.authorizedClient2); + this.request = new MockHttpServletRequest(); } @After @@ -71,8 +117,20 @@ public class OAuth2AuthorizedClientArgumentResolverTests { } @Test - public void constructorWhenOAuth2AuthorizedClientServiceIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(null)) + public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(null, this.authorizedClientRepository)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void constructorWhenOAuth2AuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(this.clientRegistrationRepository, null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void setClientCredentialsTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.argumentResolver.setClientCredentialsTokenResponseClient(null)) .isInstanceOf(IllegalArgumentException.class); } @@ -101,7 +159,7 @@ public class OAuth2AuthorizedClientArgumentResolverTests { } @Test - public void resolveArgumentWhenRegistrationIdEmptyAndNotOAuth2AuthenticationThenThrowIllegalArgumentException() throws Exception { + public void resolveArgumentWhenRegistrationIdEmptyAndNotOAuth2AuthenticationThenThrowIllegalArgumentException() { MethodParameter methodParameter = this.getMethodParameter("registrationIdEmpty", OAuth2AuthorizedClient.class); assertThatThrownBy(() -> this.argumentResolver.resolveArgument(methodParameter, null, null, null)) .isInstanceOf(IllegalArgumentException.class) @@ -116,18 +174,26 @@ public class OAuth2AuthorizedClientArgumentResolverTests { securityContext.setAuthentication(authentication); SecurityContextHolder.setContext(securityContext); MethodParameter methodParameter = this.getMethodParameter("registrationIdEmpty", OAuth2AuthorizedClient.class); - this.argumentResolver.resolveArgument(methodParameter, null, new ServletWebRequest(this.request), null); + assertThat(this.argumentResolver.resolveArgument( + methodParameter, null, new ServletWebRequest(this.request), null)).isSameAs(this.authorizedClient1); } @Test - public void resolveArgumentWhenOAuth2AuthorizedClientFoundThenResolves() throws Exception { + public void resolveArgumentWhenAuthorizedClientFoundThenResolves() throws Exception { MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class); assertThat(this.argumentResolver.resolveArgument( - methodParameter, null, new ServletWebRequest(this.request), null)).isSameAs(this.authorizedClient); + methodParameter, null, new ServletWebRequest(this.request), null)).isSameAs(this.authorizedClient1); } @Test - public void resolveArgumentWhenOAuth2AuthorizedClientNotFoundThenThrowClientAuthorizationRequiredException() throws Exception { + public void resolveArgumentWhenRegistrationIdInvalidThenDoesNotResolve() throws Exception { + MethodParameter methodParameter = this.getMethodParameter("registrationIdInvalid", OAuth2AuthorizedClient.class); + assertThat(this.argumentResolver.resolveArgument( + methodParameter, null, new ServletWebRequest(this.request), null)).isNull(); + } + + @Test + public void resolveArgumentWhenAuthorizedClientNotFoundForAuthorizationCodeClientThenThrowClientAuthorizationRequiredException() { when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any(HttpServletRequest.class))) .thenReturn(null); MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class); @@ -135,6 +201,35 @@ public class OAuth2AuthorizedClientArgumentResolverTests { .isInstanceOf(ClientAuthorizationRequiredException.class); } + @SuppressWarnings("unchecked") + @Test + public void resolveArgumentWhenAuthorizedClientNotFoundForClientCredentialsClientThenResolvesFromTokenResponseClient() throws Exception { + OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient = + mock(OAuth2AccessTokenResponseClient.class); + this.argumentResolver.setClientCredentialsTokenResponseClient(clientCredentialsTokenResponseClient); + OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse + .withToken("access-token-1234") + .tokenType(OAuth2AccessToken.TokenType.BEARER) + .expiresIn(3600) + .build(); + when(clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); + + when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any(HttpServletRequest.class))) + .thenReturn(null); + MethodParameter methodParameter = this.getMethodParameter("clientCredentialsClient", OAuth2AuthorizedClient.class); + + OAuth2AuthorizedClient authorizedClient = (OAuth2AuthorizedClient) this.argumentResolver.resolveArgument( + methodParameter, null, new ServletWebRequest(this.request), null); + + assertThat(authorizedClient).isNotNull(); + assertThat(authorizedClient.getClientRegistration()).isSameAs(this.registration2); + assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principalName); + assertThat(authorizedClient.getAccessToken()).isSameAs(accessTokenResponse.getAccessToken()); + + verify(this.authorizedClientRepository).saveAuthorizedClient( + eq(authorizedClient), eq(this.authentication), any(HttpServletRequest.class), eq(null)); + } + private MethodParameter getMethodParameter(String methodName, Class... paramTypes) { Method method = ReflectionUtils.findMethod(TestController.class, methodName, paramTypes); return new MethodParameter(method, 0); @@ -155,5 +250,11 @@ public class OAuth2AuthorizedClientArgumentResolverTests { void registrationIdEmpty(@RegisteredOAuth2AuthorizedClient OAuth2AuthorizedClient authorizedClient) { } + + void registrationIdInvalid(@RegisteredOAuth2AuthorizedClient("invalid") OAuth2AuthorizedClient authorizedClient) { + } + + void clientCredentialsClient(@RegisteredOAuth2AuthorizedClient("client2") OAuth2AuthorizedClient authorizedClient) { + } } } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/AuthorizationGrantType.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/AuthorizationGrantType.java index 4e58a2c6f1..1a0af57806 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/AuthorizationGrantType.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/AuthorizationGrantType.java @@ -38,6 +38,7 @@ public final class AuthorizationGrantType implements Serializable { public static final AuthorizationGrantType AUTHORIZATION_CODE = new AuthorizationGrantType("authorization_code"); public static final AuthorizationGrantType IMPLICIT = new AuthorizationGrantType("implicit"); public static final AuthorizationGrantType REFRESH_TOKEN = new AuthorizationGrantType("refresh_token"); + public static final AuthorizationGrantType CLIENT_CREDENTIALS = new AuthorizationGrantType("client_credentials"); private final String value; /** diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2ParameterNames.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2ParameterNames.java index eed944b016..c1061e7c2e 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2ParameterNames.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2ParameterNames.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,6 +25,11 @@ package org.springframework.security.oauth2.core.endpoint; */ public interface OAuth2ParameterNames { + /** + * {@code grant_type} - used in Access Token Request. + */ + String GRANT_TYPE = "grant_type"; + /** * {@code response_type} - used in Authorization Request. */ @@ -35,6 +40,11 @@ public interface OAuth2ParameterNames { */ String CLIENT_ID = "client_id"; + /** + * {@code client_secret} - used in Access Token Request. + */ + String CLIENT_SECRET = "client_secret"; + /** * {@code redirect_uri} - used in Authorization Request and Access Token Request. */ @@ -55,6 +65,26 @@ public interface OAuth2ParameterNames { */ String CODE = "code"; + /** + * {@code access_token} - used in Authorization Response and Access Token Response. + */ + String ACCESS_TOKEN = "access_token"; + + /** + * {@code token_type} - used in Authorization Response and Access Token Response. + */ + String TOKEN_TYPE = "token_type"; + + /** + * {@code expires_in} - used in Authorization Response and Access Token Response. + */ + String EXPIRES_IN = "expires_in"; + + /** + * {@code refresh_token} - used in Access Token Request and Access Token Response. + */ + String REFRESH_TOKEN = "refresh_token"; + /** * {@code error} - used in Authorization Response and Access Token Response. */