From eae7afd9aa963581ea638a4385d49b6571fc5e74 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Tue, 30 Jan 2018 15:18:03 -0500 Subject: [PATCH] Add support for authorization_code grant Fixes gh-4928 --- .../web/builders/FilterComparator.java | 4 + .../AuthorizationCodeGrantConfigurer.java | 243 +++++++++++++++++ .../client/ImplicitGrantConfigurer.java | 16 +- .../client/OAuth2ClientConfigurerUtils.java | 69 +++++ .../oauth2/client/OAuth2LoginConfigurer.java | 46 +--- ...AuthorizationCodeGrantConfigurerTests.java | 165 ++++++++++++ ...thorizationCodeAuthenticationProvider.java | 88 ++++++ ...2AuthorizationCodeAuthenticationToken.java | 114 ++++++++ .../OAuth2AuthorizationExchangeValidator.java | 54 ++++ .../OAuth2LoginAuthenticationProvider.java | 27 +- .../OAuth2AuthorizationCodeGrantFilter.java | 192 +++++++++++++ .../web/OAuth2AuthorizationResponseUtils.java | 72 +++++ .../web/OAuth2LoginAuthenticationFilter.java | 38 +-- ...zationCodeAuthenticationProviderTests.java | 146 ++++++++++ ...orizationCodeAuthenticationTokenTests.java | 109 ++++++++ ...uth2AuthorizationCodeGrantFilterTests.java | 252 ++++++++++++++++++ .../OAuth2LoginAuthenticationFilterTests.java | 22 -- 17 files changed, 1518 insertions(+), 139 deletions(-) create mode 100644 config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/AuthorizationCodeGrantConfigurer.java create mode 100644 config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerUtils.java create mode 100644 config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/AuthorizationCodeGrantConfigurerTests.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationToken.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationExchangeValidator.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationResponseUtils.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationTokenTests.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/builders/FilterComparator.java b/config/src/main/java/org/springframework/security/config/annotation/web/builders/FilterComparator.java index 632da9b14d..20f336d609 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/builders/FilterComparator.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/builders/FilterComparator.java @@ -117,6 +117,10 @@ final class FilterComparator implements Comparator, Serializable { order += STEP; put(AnonymousAuthenticationFilter.class, order); order += STEP; + filterToOrder.put( + "org.springframework.security.oauth2.client.web.OAuth2AuthorizationCodeGrantFilter", + order); + order += STEP; put(SessionManagementFilter.class, order); order += STEP; put(ExceptionTranslationFilter.class, order); diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/AuthorizationCodeGrantConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/AuthorizationCodeGrantConfigurer.java new file mode 100644 index 0000000000..2a15fe3b37 --- /dev/null +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/AuthorizationCodeGrantConfigurer.java @@ -0,0 +1,243 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.config.annotation.web.configurers.oauth2.client; + +import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.config.annotation.web.HttpSecurityBuilder; +import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationProvider; +import org.springframework.security.oauth2.client.endpoint.NimbusAuthorizationCodeTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizationCodeGrantFilter; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.util.Assert; + +/** + * An {@link AbstractHttpConfigurer} for the OAuth 2.0 Authorization Code Grant. + * + *

+ * Defaults are provided for all configuration options with the only required configuration + * being {@link #clientRegistrationRepository(ClientRegistrationRepository)}. + * Alternatively, a {@link ClientRegistrationRepository} {@code @Bean} may be registered instead. + * + *

Security Filters

+ * + * The following {@code Filter}'s are populated: + * + * + * + *

Shared Objects Created

+ * + * The following shared objects are populated: + * + * + * + *

Shared Objects Used

+ * + * The following shared objects are used: + * + * + * + * @author Joe Grandja + * @since 5.1 + * @see OAuth2AuthorizationRequestRedirectFilter + * @see OAuth2AuthorizationCodeGrantFilter + * @see ClientRegistrationRepository + * @see OAuth2AuthorizedClientService + * @see AbstractHttpConfigurer + */ +public final class AuthorizationCodeGrantConfigurer> extends + AbstractHttpConfigurer, B> { + + private final AuthorizationEndpointConfig authorizationEndpointConfig = new AuthorizationEndpointConfig(); + private final TokenEndpointConfig tokenEndpointConfig = new TokenEndpointConfig(); + + /** + * Sets the repository of client registrations. + * + * @param clientRegistrationRepository the repository of client registrations + * @return the {@link AuthorizationCodeGrantConfigurer} for further configuration + */ + public AuthorizationCodeGrantConfigurer clientRegistrationRepository(ClientRegistrationRepository clientRegistrationRepository) { + Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); + this.getBuilder().setSharedObject(ClientRegistrationRepository.class, clientRegistrationRepository); + return this; + } + + /** + * Sets the service for authorized client(s). + * + * @param authorizedClientService the authorized client service + * @return the {@link AuthorizationCodeGrantConfigurer} for further configuration + */ + public AuthorizationCodeGrantConfigurer authorizedClientService(OAuth2AuthorizedClientService authorizedClientService) { + Assert.notNull(authorizedClientService, "authorizedClientService cannot be null"); + this.getBuilder().setSharedObject(OAuth2AuthorizedClientService.class, authorizedClientService); + return this; + } + + /** + * Returns the {@link AuthorizationEndpointConfig} for configuring the Authorization Server's Authorization Endpoint. + * + * @return the {@link AuthorizationEndpointConfig} + */ + public AuthorizationEndpointConfig authorizationEndpoint() { + return this.authorizationEndpointConfig; + } + + /** + * Configuration options for the Authorization Server's Authorization Endpoint. + */ + public class AuthorizationEndpointConfig { + private String authorizationRequestBaseUri; + private AuthorizationRequestRepository authorizationRequestRepository; + + private AuthorizationEndpointConfig() { + } + + /** + * Sets the base {@code URI} used for authorization requests. + * + * @param authorizationRequestBaseUri the base {@code URI} used for authorization requests + * @return the {@link AuthorizationEndpointConfig} for further configuration + */ + public AuthorizationEndpointConfig baseUri(String authorizationRequestBaseUri) { + Assert.hasText(authorizationRequestBaseUri, "authorizationRequestBaseUri cannot be empty"); + this.authorizationRequestBaseUri = authorizationRequestBaseUri; + return this; + } + + /** + * Sets the repository used for storing {@link OAuth2AuthorizationRequest}'s. + * + * @param authorizationRequestRepository the repository used for storing {@link OAuth2AuthorizationRequest}'s + * @return the {@link AuthorizationEndpointConfig} for further configuration + */ + public AuthorizationEndpointConfig authorizationRequestRepository(AuthorizationRequestRepository authorizationRequestRepository) { + Assert.notNull(authorizationRequestRepository, "authorizationRequestRepository cannot be null"); + this.authorizationRequestRepository = authorizationRequestRepository; + return this; + } + + /** + * Returns the {@link AuthorizationCodeGrantConfigurer} for further configuration. + * + * @return the {@link AuthorizationCodeGrantConfigurer} + */ + public AuthorizationCodeGrantConfigurer and() { + return AuthorizationCodeGrantConfigurer.this; + } + } + + /** + * Returns the {@link TokenEndpointConfig} for configuring the Authorization Server's Token Endpoint. + * + * @return the {@link TokenEndpointConfig} + */ + public TokenEndpointConfig tokenEndpoint() { + return this.tokenEndpointConfig; + } + + /** + * Configuration options for the Authorization Server's Token Endpoint. + */ + public class TokenEndpointConfig { + private OAuth2AccessTokenResponseClient accessTokenResponseClient; + + private TokenEndpointConfig() { + } + + /** + * Sets the client used for requesting the access token credential from the Token Endpoint. + * + * @param accessTokenResponseClient the client used for requesting the access token credential from the Token Endpoint + * @return the {@link TokenEndpointConfig} for further configuration + */ + public TokenEndpointConfig accessTokenResponseClient( + OAuth2AccessTokenResponseClient accessTokenResponseClient) { + + Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null"); + this.accessTokenResponseClient = accessTokenResponseClient; + return this; + } + + /** + * Returns the {@link AuthorizationCodeGrantConfigurer} for further configuration. + * + * @return the {@link AuthorizationCodeGrantConfigurer} + */ + public AuthorizationCodeGrantConfigurer and() { + return AuthorizationCodeGrantConfigurer.this; + } + } + + @Override + public void init(B http) throws Exception { + OAuth2AccessTokenResponseClient accessTokenResponseClient = + this.tokenEndpointConfig.accessTokenResponseClient; + if (accessTokenResponseClient == null) { + accessTokenResponseClient = new NimbusAuthorizationCodeTokenResponseClient(); + } + + OAuth2AuthorizationCodeAuthenticationProvider authorizationCodeAuthenticationProvider = + new OAuth2AuthorizationCodeAuthenticationProvider(accessTokenResponseClient); + http.authenticationProvider(this.postProcess(authorizationCodeAuthenticationProvider)); + } + + @Override + public void configure(B http) throws Exception { + String authorizationRequestBaseUri = this.authorizationEndpointConfig.authorizationRequestBaseUri; + if (authorizationRequestBaseUri == null) { + authorizationRequestBaseUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI; + } + + OAuth2AuthorizationRequestRedirectFilter authorizationRequestFilter = new OAuth2AuthorizationRequestRedirectFilter( + OAuth2ClientConfigurerUtils.getClientRegistrationRepository(this.getBuilder()), authorizationRequestBaseUri); + + if (this.authorizationEndpointConfig.authorizationRequestRepository != null) { + authorizationRequestFilter.setAuthorizationRequestRepository( + this.authorizationEndpointConfig.authorizationRequestRepository); + } + http.addFilter(this.postProcess(authorizationRequestFilter)); + + AuthenticationManager authenticationManager = http.getSharedObject(AuthenticationManager.class); + + OAuth2AuthorizationCodeGrantFilter authorizationCodeGrantFilter = new OAuth2AuthorizationCodeGrantFilter( + OAuth2ClientConfigurerUtils.getClientRegistrationRepository(this.getBuilder()), + OAuth2ClientConfigurerUtils.getAuthorizedClientService(this.getBuilder()), + authenticationManager); + + if (this.authorizationEndpointConfig.authorizationRequestRepository != null) { + authorizationCodeGrantFilter.setAuthorizationRequestRepository( + this.authorizationEndpointConfig.authorizationRequestRepository); + } + http.addFilter(this.postProcess(authorizationCodeGrantFilter)); + } +} diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/ImplicitGrantConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/ImplicitGrantConfigurer.java index b860818a67..3917f522ec 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/ImplicitGrantConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/ImplicitGrantConfigurer.java @@ -15,7 +15,6 @@ */ package org.springframework.security.config.annotation.web.configurers.oauth2.client; -import org.springframework.context.ApplicationContext; import org.springframework.security.config.annotation.web.HttpSecurityBuilder; import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; @@ -86,7 +85,7 @@ public final class ImplicitGrantConfigurer> ext @Override public void configure(B http) throws Exception { OAuth2AuthorizationRequestRedirectFilter authorizationRequestFilter = new OAuth2AuthorizationRequestRedirectFilter( - this.getClientRegistrationRepository(), this.getAuthorizationRequestBaseUri()); + OAuth2ClientConfigurerUtils.getClientRegistrationRepository(this.getBuilder()), this.getAuthorizationRequestBaseUri()); http.addFilter(this.postProcess(authorizationRequestFilter)); } @@ -95,17 +94,4 @@ public final class ImplicitGrantConfigurer> ext this.authorizationRequestBaseUri : OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI; } - - private ClientRegistrationRepository getClientRegistrationRepository() { - ClientRegistrationRepository clientRegistrationRepository = this.getBuilder().getSharedObject(ClientRegistrationRepository.class); - if (clientRegistrationRepository == null) { - clientRegistrationRepository = this.getClientRegistrationRepositoryBean(); - this.getBuilder().setSharedObject(ClientRegistrationRepository.class, clientRegistrationRepository); - } - return clientRegistrationRepository; - } - - private ClientRegistrationRepository getClientRegistrationRepositoryBean() { - return this.getBuilder().getSharedObject(ApplicationContext.class).getBean(ClientRegistrationRepository.class); - } } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerUtils.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerUtils.java new file mode 100644 index 0000000000..451268e60b --- /dev/null +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerUtils.java @@ -0,0 +1,69 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.config.annotation.web.configurers.oauth2.client; + +import org.springframework.beans.factory.BeanFactoryUtils; +import org.springframework.context.ApplicationContext; +import org.springframework.security.config.annotation.web.HttpSecurityBuilder; +import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer; +import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; + +import java.util.Map; + +/** + * Utility methods for the OAuth 2.0 Client {@link AbstractHttpConfigurer}'s. + * + * @author Joe Grandja + * @since 5.1 + */ +final class OAuth2ClientConfigurerUtils { + + private OAuth2ClientConfigurerUtils() { + } + + static > ClientRegistrationRepository getClientRegistrationRepository(B builder) { + ClientRegistrationRepository clientRegistrationRepository = builder.getSharedObject(ClientRegistrationRepository.class); + if (clientRegistrationRepository == null) { + clientRegistrationRepository = getClientRegistrationRepositoryBean(builder); + builder.setSharedObject(ClientRegistrationRepository.class, clientRegistrationRepository); + } + return clientRegistrationRepository; + } + + private static > ClientRegistrationRepository getClientRegistrationRepositoryBean(B builder) { + return builder.getSharedObject(ApplicationContext.class).getBean(ClientRegistrationRepository.class); + } + + static > OAuth2AuthorizedClientService getAuthorizedClientService(B builder) { + OAuth2AuthorizedClientService authorizedClientService = builder.getSharedObject(OAuth2AuthorizedClientService.class); + if (authorizedClientService == null) { + authorizedClientService = getAuthorizedClientServiceBean(builder); + if (authorizedClientService == null) { + authorizedClientService = new InMemoryOAuth2AuthorizedClientService(getClientRegistrationRepository(builder)); + } + builder.setSharedObject(OAuth2AuthorizedClientService.class, authorizedClientService); + } + return authorizedClientService; + } + + private static > OAuth2AuthorizedClientService getAuthorizedClientServiceBean(B builder) { + Map authorizedClientServiceMap = BeanFactoryUtils.beansOfTypeIncludingAncestors( + builder.getSharedObject(ApplicationContext.class), OAuth2AuthorizedClientService.class); + return (!authorizedClientServiceMap.isEmpty() ? authorizedClientServiceMap.values().iterator().next() : null); + } +} diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java index 64ae269667..7b924691f8 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java @@ -26,7 +26,6 @@ import org.springframework.security.config.annotation.web.configurers.AbstractHt import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; -import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService; import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationProvider; import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken; @@ -376,8 +375,8 @@ public final class OAuth2LoginConfigurer> exten public void init(B http) throws Exception { OAuth2LoginAuthenticationFilter authenticationFilter = new OAuth2LoginAuthenticationFilter( - this.getClientRegistrationRepository(), - this.getAuthorizedClientService(), + OAuth2ClientConfigurerUtils.getClientRegistrationRepository(this.getBuilder()), + OAuth2ClientConfigurerUtils.getAuthorizedClientService(this.getBuilder()), OAuth2LoginAuthenticationFilter.DEFAULT_FILTER_PROCESSES_URI); this.setAuthenticationFilter(authenticationFilter); this.loginProcessingUrl(OAuth2LoginAuthenticationFilter.DEFAULT_FILTER_PROCESSES_URI); @@ -442,7 +441,7 @@ public final class OAuth2LoginConfigurer> exten } OAuth2AuthorizationRequestRedirectFilter authorizationRequestFilter = new OAuth2AuthorizationRequestRedirectFilter( - this.getClientRegistrationRepository(), authorizationRequestBaseUri); + OAuth2ClientConfigurerUtils.getClientRegistrationRepository(this.getBuilder()), authorizationRequestBaseUri); if (this.authorizationEndpointConfig.authorizationRequestRepository != null) { authorizationRequestFilter.setAuthorizationRequestRepository( @@ -466,41 +465,6 @@ public final class OAuth2LoginConfigurer> exten return new AntPathRequestMatcher(loginProcessingUrl); } - private ClientRegistrationRepository getClientRegistrationRepository() { - ClientRegistrationRepository clientRegistrationRepository = - this.getBuilder().getSharedObject(ClientRegistrationRepository.class); - if (clientRegistrationRepository == null) { - clientRegistrationRepository = this.getClientRegistrationRepositoryBean(); - this.getBuilder().setSharedObject(ClientRegistrationRepository.class, clientRegistrationRepository); - } - return clientRegistrationRepository; - } - - private ClientRegistrationRepository getClientRegistrationRepositoryBean() { - return this.getBuilder().getSharedObject(ApplicationContext.class).getBean(ClientRegistrationRepository.class); - } - - private OAuth2AuthorizedClientService getAuthorizedClientService() { - OAuth2AuthorizedClientService authorizedClientService = - this.getBuilder().getSharedObject(OAuth2AuthorizedClientService.class); - if (authorizedClientService == null) { - authorizedClientService = this.getAuthorizedClientServiceBean(); - if (authorizedClientService == null) { - authorizedClientService = new InMemoryOAuth2AuthorizedClientService(this.getClientRegistrationRepository()); - } - this.getBuilder().setSharedObject(OAuth2AuthorizedClientService.class, authorizedClientService); - } - return authorizedClientService; - } - - private OAuth2AuthorizedClientService getAuthorizedClientServiceBean() { - Map authorizedClientServiceMap = - BeanFactoryUtils.beansOfTypeIncludingAncestors( - this.getBuilder().getSharedObject(ApplicationContext.class), - OAuth2AuthorizedClientService.class); - return (!authorizedClientServiceMap.isEmpty() ? authorizedClientServiceMap.values().iterator().next() : null); - } - private GrantedAuthoritiesMapper getGrantedAuthoritiesMapper() { GrantedAuthoritiesMapper grantedAuthoritiesMapper = this.getBuilder().getSharedObject(GrantedAuthoritiesMapper.class); @@ -528,7 +492,8 @@ public final class OAuth2LoginConfigurer> exten } Iterable clientRegistrations = null; - ClientRegistrationRepository clientRegistrationRepository = this.getClientRegistrationRepository(); + ClientRegistrationRepository clientRegistrationRepository = + OAuth2ClientConfigurerUtils.getClientRegistrationRepository(this.getBuilder()); ResolvableType type = ResolvableType.forInstance(clientRegistrationRepository).as(Iterable.class); if (type != ResolvableType.NONE && ClientRegistration.class.isAssignableFrom(type.resolveGenerics()[0])) { clientRegistrations = (Iterable) clientRegistrationRepository; @@ -580,5 +545,4 @@ public final class OAuth2LoginConfigurer> exten return OAuth2LoginAuthenticationToken.class.isAssignableFrom(authentication); } } - } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/AuthorizationCodeGrantConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/AuthorizationCodeGrantConfigurerTests.java new file mode 100644 index 0000000000..6816fa7ef4 --- /dev/null +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/AuthorizationCodeGrantConfigurerTests.java @@ -0,0 +1,165 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.config.annotation.web.configurers.oauth2.client; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.mock.web.MockHttpSession; +import org.springframework.security.config.annotation.web.builders.HttpSecurity; +import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; +import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; +import org.springframework.security.config.test.SpringTestRule; +import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService; +import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.MvcResult; + +import java.util.HashMap; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + +/** + * Tests for {@link AuthorizationCodeGrantConfigurer}. + * + * @author Joe Grandja + */ +@PrepareForTest({OAuth2AuthorizationRequest.class, OAuth2AccessTokenResponse.class}) +@RunWith(PowerMockRunner.class) +public class AuthorizationCodeGrantConfigurerTests { + private static ClientRegistrationRepository clientRegistrationRepository; + + private static OAuth2AuthorizedClientService authorizedClientService; + + private static OAuth2AccessTokenResponseClient accessTokenResponseClient; + + @Rule + public final SpringTestRule spring = new SpringTestRule(); + + @Autowired + private MockMvc mockMvc; + + private ClientRegistration registration1; + + @Before + public void setup() { + this.registration1 = ClientRegistration.withRegistrationId("registration-1") + .clientId("client-1") + .clientSecret("secret") + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .redirectUriTemplate("{baseUrl}/client-1") + .scope("user") + .authorizationUri("https://provider.com/oauth2/authorize") + .tokenUri("https://provider.com/oauth2/token") + .userInfoUri("https://provider.com/oauth2/user") + .userNameAttributeName("id") + .clientName("client-1") + .build(); + clientRegistrationRepository = new InMemoryClientRegistrationRepository(this.registration1); + + authorizedClientService = new InMemoryOAuth2AuthorizedClientService(clientRegistrationRepository); + + OAuth2AccessTokenResponse accessTokenResponse = mock(OAuth2AccessTokenResponse.class); + when(accessTokenResponse.getAccessToken()).thenReturn(mock(OAuth2AccessToken.class)); + accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class); + when(accessTokenResponseClient.getTokenResponse(any(OAuth2AuthorizationCodeGrantRequest.class))).thenReturn(accessTokenResponse); + } + + @Test + public void configureWhenAuthorizationRequestThenRedirectForAuthorization() throws Exception { + this.spring.register(AuthorizationCodeGrantConfig.class).autowire(); + + MvcResult mvcResult = this.mockMvc.perform(get("/oauth2/authorization/registration-1")) + .andExpect(status().is3xxRedirection()) + .andReturn(); + assertThat(mvcResult.getResponse().getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http://localhost/client-1"); + } + + @Test + public void configureWhenAuthorizationResponseSuccessThenAuthorizedClientSaved() throws Exception { + this.spring.register(AuthorizationCodeGrantConfig.class).autowire(); + + // Setup the Authorization Request in the session + Map additionalParameters = new HashMap<>(); + additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, this.registration1.getRegistrationId()); + OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class); + when(authorizationRequest.getAdditionalParameters()).thenReturn(additionalParameters); + when(authorizationRequest.getState()).thenReturn("state"); + when(authorizationRequest.getRedirectUri()).thenReturn("http://localhost/client-1"); + MockHttpSession session = new MockHttpSession(); + session.setAttribute(HttpSessionOAuth2AuthorizationRequestRepository.class.getName() + ".AUTHORIZATION_REQUEST", authorizationRequest); + + String principalName = "user1"; + + this.mockMvc.perform(get("/client-1") + .param(OAuth2ParameterNames.CODE, "code") + .param(OAuth2ParameterNames.STATE, "state") + .with(user(principalName)) + .session(session)) + .andExpect(status().is3xxRedirection()) + .andExpect(redirectedUrl("http://localhost/client-1")); + + OAuth2AuthorizedClient authorizedClient = authorizedClientService.loadAuthorizedClient( + this.registration1.getRegistrationId(), principalName); + assertThat(authorizedClient).isNotNull(); + } + + @EnableWebSecurity + static class AuthorizationCodeGrantConfig extends WebSecurityConfigurerAdapter { + @Override + protected void configure(HttpSecurity http) throws Exception { + http + .authorizeRequests() + .anyRequest().authenticated(); + + this.authorizationCodeGrant(http) + .clientRegistrationRepository(clientRegistrationRepository) + .authorizedClientService(authorizedClientService) + .tokenEndpoint() + .accessTokenResponseClient(accessTokenResponseClient); + } + + private AuthorizationCodeGrantConfigurer authorizationCodeGrant(HttpSecurity http) throws Exception { + return http.apply(new AuthorizationCodeGrantConfigurer<>()); + } + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java new file mode 100644 index 0000000000..c202a31f3e --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java @@ -0,0 +1,88 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.authentication; + +import org.springframework.security.authentication.AuthenticationProvider; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.AuthenticationException; +import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.util.Assert; + +/** + * An implementation of an {@link AuthenticationProvider} for the OAuth 2.0 Authorization Code Grant. + * + *

+ * This {@link AuthenticationProvider} is responsible for authenticating + * an Authorization Code credential with the Authorization Server's Token Endpoint + * and if valid, exchanging it for an Access Token credential. + * + * @author Joe Grandja + * @since 5.1 + * @see OAuth2AuthorizationCodeAuthenticationToken + * @see OAuth2AccessTokenResponseClient + * @see Section 4.1 Authorization Code Grant Flow + * @see Section 4.1.3 Access Token Request + * @see Section 4.1.4 Access Token Response + */ +public class OAuth2AuthorizationCodeAuthenticationProvider implements AuthenticationProvider { + private final OAuth2AccessTokenResponseClient accessTokenResponseClient; + + /** + * Constructs an {@code OAuth2AuthorizationCodeAuthenticationProvider} using the provided parameters. + * + * @param accessTokenResponseClient the client used for requesting the access token credential from the Token Endpoint + */ + public OAuth2AuthorizationCodeAuthenticationProvider( + OAuth2AccessTokenResponseClient accessTokenResponseClient) { + + Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null"); + this.accessTokenResponseClient = accessTokenResponseClient; + } + + @Override + public Authentication authenticate(Authentication authentication) throws AuthenticationException { + OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = + (OAuth2AuthorizationCodeAuthenticationToken) authentication; + + OAuth2AuthorizationExchangeValidator.validate( + authorizationCodeAuthentication.getAuthorizationExchange()); + + OAuth2AccessTokenResponse accessTokenResponse = + this.accessTokenResponseClient.getTokenResponse( + new OAuth2AuthorizationCodeGrantRequest( + authorizationCodeAuthentication.getClientRegistration(), + authorizationCodeAuthentication.getAuthorizationExchange())); + + OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken(); + + OAuth2AuthorizationCodeAuthenticationToken authenticationResult = + new OAuth2AuthorizationCodeAuthenticationToken( + authorizationCodeAuthentication.getClientRegistration(), + authorizationCodeAuthentication.getAuthorizationExchange(), + accessToken); + authenticationResult.setDetails(authorizationCodeAuthentication.getDetails()); + + return authenticationResult; + } + + @Override + public boolean supports(Class authentication) { + return OAuth2AuthorizationCodeAuthenticationToken.class.isAssignableFrom(authentication); + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationToken.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationToken.java new file mode 100644 index 0000000000..bbc6d4aca1 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationToken.java @@ -0,0 +1,114 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.authentication; + +import org.springframework.security.authentication.AbstractAuthenticationToken; +import org.springframework.security.core.SpringSecurityCoreVersion; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; +import org.springframework.util.Assert; + +import java.util.Collections; + +/** + * An {@link AbstractAuthenticationToken} for the OAuth 2.0 Authorization Code Grant. + * + * @author Joe Grandja + * @since 5.1 + * @see AbstractAuthenticationToken + * @see ClientRegistration + * @see OAuth2AuthorizationExchange + * @see OAuth2AccessToken + * @see Section 4.1 Authorization Code Grant Flow + */ +public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenticationToken { + private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; + private ClientRegistration clientRegistration; + private OAuth2AuthorizationExchange authorizationExchange; + private OAuth2AccessToken accessToken; + + /** + * This constructor should be used when the Authorization Request/Response is complete. + * + * @param clientRegistration the client registration + * @param authorizationExchange the authorization exchange + */ + public OAuth2AuthorizationCodeAuthenticationToken(ClientRegistration clientRegistration, + OAuth2AuthorizationExchange authorizationExchange) { + super(Collections.emptyList()); + Assert.notNull(clientRegistration, "clientRegistration cannot be null"); + Assert.notNull(authorizationExchange, "authorizationExchange cannot be null"); + this.clientRegistration = clientRegistration; + this.authorizationExchange = authorizationExchange; + } + + /** + * This constructor should be used when the Access Token Request/Response is complete, + * which indicates that the Authorization Code Grant flow has fully completed. + * + * @param clientRegistration the client registration + * @param authorizationExchange the authorization exchange + * @param accessToken the access token credential + */ + public OAuth2AuthorizationCodeAuthenticationToken(ClientRegistration clientRegistration, + OAuth2AuthorizationExchange authorizationExchange, + OAuth2AccessToken accessToken) { + this(clientRegistration, authorizationExchange); + Assert.notNull(accessToken, "accessToken cannot be null"); + this.accessToken = accessToken; + this.setAuthenticated(true); + } + + @Override + public Object getPrincipal() { + return this.clientRegistration.getClientId(); + } + + @Override + public Object getCredentials() { + return this.accessToken != null ? + this.accessToken.getTokenValue() : + this.authorizationExchange.getAuthorizationResponse().getCode(); + } + + /** + * Returns the {@link ClientRegistration client registration}. + * + * @return the {@link ClientRegistration} + */ + public ClientRegistration getClientRegistration() { + return this.clientRegistration; + } + + /** + * Returns the {@link OAuth2AuthorizationExchange authorization exchange}. + * + * @return the {@link OAuth2AuthorizationExchange} + */ + public OAuth2AuthorizationExchange getAuthorizationExchange() { + return this.authorizationExchange; + } + + /** + * Returns the {@link OAuth2AccessToken access token}. + * + * @return the {@link OAuth2AccessToken} + */ + public OAuth2AccessToken getAccessToken() { + return this.accessToken; + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationExchangeValidator.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationExchangeValidator.java new file mode 100644 index 0000000000..a23b09f291 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationExchangeValidator.java @@ -0,0 +1,54 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.authentication; + +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; + +/** + * A validator for an "exchange" of an OAuth 2.0 Authorization Request and Response. + * + * @author Joe Grandja + * @since 5.1 + * @see OAuth2AuthorizationExchange + */ +final class OAuth2AuthorizationExchangeValidator { + private static final String INVALID_STATE_PARAMETER_ERROR_CODE = "invalid_state_parameter"; + private static final String INVALID_REDIRECT_URI_PARAMETER_ERROR_CODE = "invalid_redirect_uri_parameter"; + + static void validate(OAuth2AuthorizationExchange authorizationExchange) { + OAuth2AuthorizationRequest authorizationRequest = authorizationExchange.getAuthorizationRequest(); + OAuth2AuthorizationResponse authorizationResponse = authorizationExchange.getAuthorizationResponse(); + + if (authorizationResponse.statusError()) { + throw new OAuth2AuthenticationException( + authorizationResponse.getError(), authorizationResponse.getError().toString()); + } + + if (!authorizationResponse.getState().equals(authorizationRequest.getState())) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_STATE_PARAMETER_ERROR_CODE); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } + + if (!authorizationResponse.getRedirectUri().equals(authorizationRequest.getRedirectUri())) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_REDIRECT_URI_PARAMETER_ERROR_CODE); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProvider.java index ab453f7c59..d3c442f334 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProvider.java @@ -25,11 +25,7 @@ import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCo import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest; import org.springframework.security.oauth2.client.userinfo.OAuth2UserService; import org.springframework.security.oauth2.core.OAuth2AccessToken; -import org.springframework.security.oauth2.core.OAuth2AuthenticationException; -import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; -import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; -import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.util.Assert; @@ -60,8 +56,6 @@ import java.util.Collection; * @see Section 4.1.4 Access Token Response */ public class OAuth2LoginAuthenticationProvider implements AuthenticationProvider { - private static final String INVALID_STATE_PARAMETER_ERROR_CODE = "invalid_state_parameter"; - private static final String INVALID_REDIRECT_URI_PARAMETER_ERROR_CODE = "invalid_redirect_uri_parameter"; private final OAuth2AccessTokenResponseClient accessTokenResponseClient; private final OAuth2UserService userService; private GrantedAuthoritiesMapper authoritiesMapper = (authorities -> authorities); @@ -97,25 +91,8 @@ public class OAuth2LoginAuthenticationProvider implements AuthenticationProvider return null; } - OAuth2AuthorizationRequest authorizationRequest = authorizationCodeAuthentication - .getAuthorizationExchange().getAuthorizationRequest(); - OAuth2AuthorizationResponse authorizationResponse = authorizationCodeAuthentication - .getAuthorizationExchange().getAuthorizationResponse(); - - if (authorizationResponse.statusError()) { - throw new OAuth2AuthenticationException( - authorizationResponse.getError(), authorizationResponse.getError().toString()); - } - - if (!authorizationResponse.getState().equals(authorizationRequest.getState())) { - OAuth2Error oauth2Error = new OAuth2Error(INVALID_STATE_PARAMETER_ERROR_CODE); - throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); - } - - if (!authorizationResponse.getRedirectUri().equals(authorizationRequest.getRedirectUri())) { - OAuth2Error oauth2Error = new OAuth2Error(INVALID_REDIRECT_URI_PARAMETER_ERROR_CODE); - throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); - } + OAuth2AuthorizationExchangeValidator.validate( + authorizationCodeAuthentication.getAuthorizationExchange()); OAuth2AccessTokenResponse accessTokenResponse = this.accessTokenResponseClient.getTokenResponse( diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java new file mode 100644 index 0000000000..44b9ce6c8e --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java @@ -0,0 +1,192 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web; + +import org.springframework.security.authentication.AuthenticationDetailsSource; +import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationProvider; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.web.DefaultRedirectStrategy; +import org.springframework.security.web.RedirectStrategy; +import org.springframework.security.web.authentication.WebAuthenticationDetailsSource; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; +import org.springframework.web.filter.OncePerRequestFilter; +import org.springframework.web.util.UriComponentsBuilder; + +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.io.IOException; + +/** + * A {@code Filter} for the OAuth 2.0 Authorization Code Grant, + * which handles the processing of the OAuth 2.0 Authorization Response. + * + *

+ * The OAuth 2.0 Authorization Response is processed as follows: + * + *

    + *
  • + * Assuming the End-User (Resource Owner) has granted access to the Client, the Authorization Server will append the + * {@link OAuth2ParameterNames#CODE code} and {@link OAuth2ParameterNames#STATE state} parameters + * to the {@link OAuth2ParameterNames#REDIRECT_URI redirect_uri} (provided in the Authorization Request) + * and redirect the End-User's user-agent back to this {@code Filter} (the Client). + *
  • + *
  • + * This {@code Filter} will then create an {@link OAuth2AuthorizationCodeAuthenticationToken} with + * the {@link OAuth2ParameterNames#CODE code} received and + * delegate it to the {@link AuthenticationManager} to authenticate. + *
  • + *
  • + * Upon a successful authentication, an {@link OAuth2AuthorizedClient Authorized Client} is created by associating the + * {@link OAuth2AuthorizationCodeAuthenticationToken#getClientRegistration() client} to the + * {@link OAuth2AuthorizationCodeAuthenticationToken#getAccessToken() access token} and current {@code Principal} + * and saving it via the {@link OAuth2AuthorizedClientService}. + *
  • + *
+ * + * @author Joe Grandja + * @since 5.1 + * @see OAuth2AuthorizationCodeAuthenticationToken + * @see OAuth2AuthorizationCodeAuthenticationProvider + * @see OAuth2AuthorizationRequest + * @see OAuth2AuthorizationResponse + * @see AuthorizationRequestRepository + * @see OAuth2AuthorizationRequestRedirectFilter + * @see ClientRegistrationRepository + * @see OAuth2AuthorizedClient + * @see OAuth2AuthorizedClientService + * @see Section 4.1 Authorization Code Grant + * @see Section 4.1.2 Authorization Response + */ +public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter { + private final ClientRegistrationRepository clientRegistrationRepository; + private final OAuth2AuthorizedClientService authorizedClientService; + private final AuthenticationManager authenticationManager; + private AuthorizationRequestRepository authorizationRequestRepository = + new HttpSessionOAuth2AuthorizationRequestRepository(); + private final AuthenticationDetailsSource authenticationDetailsSource = new WebAuthenticationDetailsSource(); + private final RedirectStrategy redirectStrategy = new DefaultRedirectStrategy(); + + /** + * Constructs an {@code OAuth2AuthorizationCodeGrantFilter} using the provided parameters. + * + * @param clientRegistrationRepository the repository of client registrations + * @param authorizedClientService the authorized client service + * @param authenticationManager the authentication manager + */ + public OAuth2AuthorizationCodeGrantFilter(ClientRegistrationRepository clientRegistrationRepository, + OAuth2AuthorizedClientService authorizedClientService, + AuthenticationManager authenticationManager) { + Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); + Assert.notNull(authorizedClientService, "authorizedClientService cannot be null"); + Assert.notNull(authenticationManager, "authenticationManager cannot be null"); + this.clientRegistrationRepository = clientRegistrationRepository; + this.authorizedClientService = authorizedClientService; + this.authenticationManager = authenticationManager; + } + + /** + * Sets the repository for stored {@link OAuth2AuthorizationRequest}'s. + * + * @param authorizationRequestRepository the repository for stored {@link OAuth2AuthorizationRequest}'s + */ + public final void setAuthorizationRequestRepository(AuthorizationRequestRepository authorizationRequestRepository) { + Assert.notNull(authorizationRequestRepository, "authorizationRequestRepository cannot be null"); + this.authorizationRequestRepository = authorizationRequestRepository; + } + + @Override + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) + throws ServletException, IOException { + + if (this.shouldProcessAuthorizationResponse(request, response)) { + this.processAuthorizationResponse(request, response); + return; + } + + filterChain.doFilter(request, response); + } + + private boolean shouldProcessAuthorizationResponse(HttpServletRequest request, HttpServletResponse response) { + if (OAuth2AuthorizationResponseUtils.authorizationResponse(request) && + (this.authorizationRequestRepository.loadAuthorizationRequest(request) != null)) { + return true; + } + return false; + } + + private void processAuthorizationResponse(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestRepository.loadAuthorizationRequest(request); + this.authorizationRequestRepository.removeAuthorizationRequest(request); + + String registrationId = (String) authorizationRequest.getAdditionalParameters().get(OAuth2ParameterNames.REGISTRATION_ID); + ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(registrationId); + + OAuth2AuthorizationResponse authorizationResponse = OAuth2AuthorizationResponseUtils.convert(request); + + OAuth2AuthorizationCodeAuthenticationToken authenticationRequest = new OAuth2AuthorizationCodeAuthenticationToken( + clientRegistration, new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse)); + authenticationRequest.setDetails(this.authenticationDetailsSource.buildDetails(request)); + + OAuth2AuthorizationCodeAuthenticationToken authenticationResult; + + try { + authenticationResult = (OAuth2AuthorizationCodeAuthenticationToken) + this.authenticationManager.authenticate(authenticationRequest); + } catch (OAuth2AuthenticationException ex) { + OAuth2Error error = ex.getError(); + UriComponentsBuilder uriBuilder = UriComponentsBuilder + .fromUriString(authorizationResponse.getRedirectUri()) + .queryParam(OAuth2ParameterNames.ERROR, error.getErrorCode()); + if (!StringUtils.isEmpty(error.getDescription())) { + uriBuilder.queryParam(OAuth2ParameterNames.ERROR_DESCRIPTION, error.getDescription()); + } + if (!StringUtils.isEmpty(error.getUri())) { + uriBuilder.queryParam(OAuth2ParameterNames.ERROR_URI, error.getUri()); + } + this.redirectStrategy.sendRedirect(request, response, uriBuilder.build().encode().toString()); + return; + } + + Authentication currentAuthentication = SecurityContextHolder.getContext().getAuthentication(); + + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + authenticationResult.getClientRegistration(), + currentAuthentication.getName(), + authenticationResult.getAccessToken()); + + this.authorizedClientService.saveAuthorizedClient(authorizedClient, currentAuthentication); + + this.redirectStrategy.sendRedirect(request, response, authorizationResponse.getRedirectUri()); + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationResponseUtils.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationResponseUtils.java new file mode 100644 index 0000000000..713816a326 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationResponseUtils.java @@ -0,0 +1,72 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web; + +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.util.StringUtils; + +import javax.servlet.http.HttpServletRequest; + +/** + * Utility methods for an OAuth 2.0 Authorization Response. + * + * @author Joe Grandja + * @since 5.1 + * @see OAuth2AuthorizationResponse + */ +final class OAuth2AuthorizationResponseUtils { + + private OAuth2AuthorizationResponseUtils() { + } + + static boolean authorizationResponse(HttpServletRequest request) { + return authorizationResponseSuccess(request) || authorizationResponseError(request); + } + + static boolean authorizationResponseSuccess(HttpServletRequest request) { + return StringUtils.hasText(request.getParameter(OAuth2ParameterNames.CODE)) && + StringUtils.hasText(request.getParameter(OAuth2ParameterNames.STATE)); + } + + static boolean authorizationResponseError(HttpServletRequest request) { + return StringUtils.hasText(request.getParameter(OAuth2ParameterNames.ERROR)) && + StringUtils.hasText(request.getParameter(OAuth2ParameterNames.STATE)); + } + + static OAuth2AuthorizationResponse convert(HttpServletRequest request) { + String code = request.getParameter(OAuth2ParameterNames.CODE); + String errorCode = request.getParameter(OAuth2ParameterNames.ERROR); + String state = request.getParameter(OAuth2ParameterNames.STATE); + String redirectUri = request.getRequestURL().toString(); + + if (StringUtils.hasText(code)) { + return OAuth2AuthorizationResponse.success(code) + .redirectUri(redirectUri) + .state(state) + .build(); + } else { + String errorDescription = request.getParameter(OAuth2ParameterNames.ERROR_DESCRIPTION); + String errorUri = request.getParameter(OAuth2ParameterNames.ERROR_URI); + return OAuth2AuthorizationResponse.error(errorCode) + .redirectUri(redirectUri) + .errorDescription(errorDescription) + .errorUri(errorUri) + .state(state) + .build(); + } + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilter.java index 67da59488e..8c7e5314dc 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilter.java @@ -35,7 +35,6 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter; import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.util.Assert; -import org.springframework.util.StringUtils; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; @@ -134,7 +133,7 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce public Authentication attemptAuthentication(HttpServletRequest request, HttpServletResponse response) throws AuthenticationException, IOException, ServletException { - if (!this.authorizationResponseSuccess(request) && !this.authorizationResponseError(request)) { + if (!OAuth2AuthorizationResponseUtils.authorizationResponse(request)) { OAuth2Error oauth2Error = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST); throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); } @@ -149,7 +148,7 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce String registrationId = (String) authorizationRequest.getAdditionalParameters().get(OAuth2ParameterNames.REGISTRATION_ID); ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(registrationId); - OAuth2AuthorizationResponse authorizationResponse = this.convert(request); + OAuth2AuthorizationResponse authorizationResponse = OAuth2AuthorizationResponseUtils.convert(request); OAuth2LoginAuthenticationToken authenticationRequest = new OAuth2LoginAuthenticationToken( clientRegistration, new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse)); @@ -182,37 +181,4 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce Assert.notNull(authorizationRequestRepository, "authorizationRequestRepository cannot be null"); this.authorizationRequestRepository = authorizationRequestRepository; } - - private OAuth2AuthorizationResponse convert(HttpServletRequest request) { - String code = request.getParameter(OAuth2ParameterNames.CODE); - String errorCode = request.getParameter(OAuth2ParameterNames.ERROR); - String state = request.getParameter(OAuth2ParameterNames.STATE); - String redirectUri = request.getRequestURL().toString(); - - if (StringUtils.hasText(code)) { - return OAuth2AuthorizationResponse.success(code) - .redirectUri(redirectUri) - .state(state) - .build(); - } else { - String errorDescription = request.getParameter(OAuth2ParameterNames.ERROR_DESCRIPTION); - String errorUri = request.getParameter(OAuth2ParameterNames.ERROR_URI); - return OAuth2AuthorizationResponse.error(errorCode) - .redirectUri(redirectUri) - .errorDescription(errorDescription) - .errorUri(errorUri) - .state(state) - .build(); - } - } - - private boolean authorizationResponseSuccess(HttpServletRequest request) { - return StringUtils.hasText(request.getParameter(OAuth2ParameterNames.CODE)) && - StringUtils.hasText(request.getParameter(OAuth2ParameterNames.STATE)); - } - - private boolean authorizationResponseError(HttpServletRequest request) { - return StringUtils.hasText(request.getParameter(OAuth2ParameterNames.ERROR)) && - StringUtils.hasText(request.getParameter(OAuth2ParameterNames.STATE)); - } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java new file mode 100644 index 0000000000..aee80f32ee --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java @@ -0,0 +1,146 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.authentication; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; +import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; + +import java.util.Collections; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.hamcrest.CoreMatchers.containsString; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link OAuth2AuthorizationCodeAuthenticationProvider}. + * + * @author Joe Grandja + */ +@PrepareForTest({ClientRegistration.class, OAuth2AuthorizationRequest.class, + OAuth2AuthorizationResponse.class, OAuth2AccessTokenResponse.class}) +@RunWith(PowerMockRunner.class) +public class OAuth2AuthorizationCodeAuthenticationProviderTests { + private ClientRegistration clientRegistration; + private OAuth2AuthorizationRequest authorizationRequest; + private OAuth2AuthorizationResponse authorizationResponse; + private OAuth2AuthorizationExchange authorizationExchange; + private OAuth2AccessTokenResponseClient accessTokenResponseClient; + private OAuth2AuthorizationCodeAuthenticationProvider authenticationProvider; + + @Rule + public ExpectedException exception = ExpectedException.none(); + + @Before + @SuppressWarnings("unchecked") + public void setUp() throws Exception { + this.clientRegistration = mock(ClientRegistration.class); + this.authorizationRequest = mock(OAuth2AuthorizationRequest.class); + this.authorizationResponse = mock(OAuth2AuthorizationResponse.class); + this.authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, this.authorizationResponse); + this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class); + this.authenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider(this.accessTokenResponseClient); + + when(this.authorizationRequest.getState()).thenReturn("12345"); + when(this.authorizationResponse.getState()).thenReturn("12345"); + when(this.authorizationRequest.getRedirectUri()).thenReturn("http://example.com"); + when(this.authorizationResponse.getRedirectUri()).thenReturn("http://example.com"); + } + + @Test + public void constructorWhenAccessTokenResponseClientIsNullThenThrowIllegalArgumentException() { + this.exception.expect(IllegalArgumentException.class); + new OAuth2AuthorizationCodeAuthenticationProvider(null); + } + + @Test + public void supportsWhenTypeOAuth2AuthorizationCodeAuthenticationTokenThenReturnTrue() { + assertThat(this.authenticationProvider.supports(OAuth2AuthorizationCodeAuthenticationToken.class)).isTrue(); + } + + @Test + public void authenticateWhenAuthorizationErrorResponseThenThrowOAuth2AuthenticationException() { + this.exception.expect(OAuth2AuthenticationException.class); + this.exception.expectMessage(containsString(OAuth2ErrorCodes.INVALID_REQUEST)); + + when(this.authorizationResponse.statusError()).thenReturn(true); + when(this.authorizationResponse.getError()).thenReturn(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST)); + + this.authenticationProvider.authenticate( + new OAuth2AuthorizationCodeAuthenticationToken(this.clientRegistration, this.authorizationExchange)); + } + + @Test + public void authenticateWhenAuthorizationResponseStateNotEqualAuthorizationRequestStateThenThrowOAuth2AuthenticationException() { + this.exception.expect(OAuth2AuthenticationException.class); + this.exception.expectMessage(containsString("invalid_state_parameter")); + + when(this.authorizationRequest.getState()).thenReturn("12345"); + when(this.authorizationResponse.getState()).thenReturn("67890"); + + this.authenticationProvider.authenticate( + new OAuth2AuthorizationCodeAuthenticationToken(this.clientRegistration, this.authorizationExchange)); + } + + @Test + public void authenticateWhenAuthorizationResponseRedirectUriNotEqualAuthorizationRequestRedirectUriThenThrowOAuth2AuthenticationException() { + this.exception.expect(OAuth2AuthenticationException.class); + this.exception.expectMessage(containsString("invalid_redirect_uri_parameter")); + + when(this.authorizationRequest.getRedirectUri()).thenReturn("http://example.com"); + when(this.authorizationResponse.getRedirectUri()).thenReturn("http://example2.com"); + + this.authenticationProvider.authenticate( + new OAuth2AuthorizationCodeAuthenticationToken(this.clientRegistration, this.authorizationExchange)); + } + + @Test + public void authenticateWhenAuthorizationSuccessResponseThenExchangedForAccessToken() { + OAuth2AccessToken accessToken = mock(OAuth2AccessToken.class); + OAuth2AccessTokenResponse accessTokenResponse = mock(OAuth2AccessTokenResponse.class); + when(accessTokenResponse.getAccessToken()).thenReturn(accessToken); + when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); + + OAuth2AuthorizationCodeAuthenticationToken authenticationResult = + (OAuth2AuthorizationCodeAuthenticationToken) this.authenticationProvider.authenticate( + new OAuth2AuthorizationCodeAuthenticationToken(this.clientRegistration, this.authorizationExchange)); + + assertThat(authenticationResult.isAuthenticated()).isTrue(); + assertThat(authenticationResult.getPrincipal()).isEqualTo(this.clientRegistration.getClientId()); + assertThat(authenticationResult.getCredentials()).isEqualTo(accessToken.getTokenValue()); + assertThat(authenticationResult.getAuthorities()).isEqualTo(Collections.emptyList()); + assertThat(authenticationResult.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authenticationResult.getAuthorizationExchange()).isEqualTo(this.authorizationExchange); + assertThat(authenticationResult.getAccessToken()).isEqualTo(accessToken); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationTokenTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationTokenTests.java new file mode 100644 index 0000000000..ed2c75a782 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationTokenTests.java @@ -0,0 +1,109 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.authentication; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; + +import java.util.Collections; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link OAuth2AuthorizationCodeAuthenticationToken}. + * + * @author Joe Grandja + */ +@RunWith(PowerMockRunner.class) +@PrepareForTest({ClientRegistration.class, OAuth2AuthorizationExchange.class, OAuth2AuthorizationResponse.class}) +public class OAuth2AuthorizationCodeAuthenticationTokenTests { + private ClientRegistration clientRegistration; + private OAuth2AuthorizationExchange authorizationExchange; + private OAuth2AccessToken accessToken; + + @Before + public void setUp() { + this.clientRegistration = mock(ClientRegistration.class); + this.authorizationExchange = mock(OAuth2AuthorizationExchange.class); + this.accessToken = mock(OAuth2AccessToken.class); + } + + @Test(expected = IllegalArgumentException.class) + public void constructorAuthorizationRequestResponseWhenClientRegistrationIsNullThenThrowIllegalArgumentException() { + new OAuth2AuthorizationCodeAuthenticationToken(null, this.authorizationExchange); + } + + @Test(expected = IllegalArgumentException.class) + public void constructorAuthorizationRequestResponseWhenAuthorizationExchangeIsNullThenThrowIllegalArgumentException() { + new OAuth2AuthorizationCodeAuthenticationToken(this.clientRegistration, null); + } + + @Test + public void constructorAuthorizationRequestResponseWhenAllParametersProvidedAndValidThenCreated() { + OAuth2AuthorizationResponse authorizationResponse = mock(OAuth2AuthorizationResponse.class); + when(authorizationResponse.getCode()).thenReturn("code"); + when(this.authorizationExchange.getAuthorizationResponse()).thenReturn(authorizationResponse); + + OAuth2AuthorizationCodeAuthenticationToken authentication = + new OAuth2AuthorizationCodeAuthenticationToken(this.clientRegistration, this.authorizationExchange); + + assertThat(authentication.getPrincipal()).isEqualTo(this.clientRegistration.getClientId()); + assertThat(authentication.getCredentials()).isEqualTo(this.authorizationExchange.getAuthorizationResponse().getCode()); + assertThat(authentication.getAuthorities()).isEqualTo(Collections.emptyList()); + assertThat(authentication.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authentication.getAuthorizationExchange()).isEqualTo(this.authorizationExchange); + assertThat(authentication.getAccessToken()).isNull(); + assertThat(authentication.isAuthenticated()).isEqualTo(false); + } + + @Test(expected = IllegalArgumentException.class) + public void constructorTokenRequestResponseWhenClientRegistrationIsNullThenThrowIllegalArgumentException() { + new OAuth2AuthorizationCodeAuthenticationToken(null, this.authorizationExchange, this.accessToken); + } + + @Test(expected = IllegalArgumentException.class) + public void constructorTokenRequestResponseWhenAuthorizationExchangeIsNullThenThrowIllegalArgumentException() { + new OAuth2AuthorizationCodeAuthenticationToken(this.clientRegistration, null, this.accessToken); + } + + @Test(expected = IllegalArgumentException.class) + public void constructorTokenRequestResponseWhenAccessTokenIsNullThenThrowIllegalArgumentException() { + new OAuth2AuthorizationCodeAuthenticationToken(this.clientRegistration, this.authorizationExchange, null); + } + + @Test + public void constructorTokenRequestResponseWhenAllParametersProvidedAndValidThenCreated() { + OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken( + this.clientRegistration, this.authorizationExchange, this.accessToken); + + assertThat(authentication.getPrincipal()).isEqualTo(this.clientRegistration.getClientId()); + assertThat(authentication.getCredentials()).isEqualTo(this.accessToken.getTokenValue()); + assertThat(authentication.getAuthorities()).isEqualTo(Collections.emptyList()); + assertThat(authentication.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authentication.getAuthorizationExchange()).isEqualTo(this.authorizationExchange); + assertThat(authentication.getAccessToken()).isEqualTo(this.accessToken); + assertThat(authentication.isAuthenticated()).isEqualTo(true); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java new file mode 100644 index 0000000000..9ea87d5625 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java @@ -0,0 +1,252 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.core.classloader.annotations.PowerMockIgnore; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; + +import javax.servlet.FilterChain; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.util.HashMap; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.*; + +/** + * Tests for {@link OAuth2AuthorizationCodeGrantFilter}. + * + * @author Joe Grandja + */ +@PowerMockIgnore("javax.security.*") +@PrepareForTest({OAuth2AuthorizationRequest.class, OAuth2AuthorizationExchange.class, OAuth2AuthorizationCodeGrantFilter.class}) +@RunWith(PowerMockRunner.class) +public class OAuth2AuthorizationCodeGrantFilterTests { + private ClientRegistration registration1; + private String principalName1 = "principal-1"; + private ClientRegistrationRepository clientRegistrationRepository; + private OAuth2AuthorizedClientService authorizedClientService; + private AuthenticationManager authenticationManager; + private AuthorizationRequestRepository authorizationRequestRepository; + private OAuth2AuthorizationCodeGrantFilter filter; + + @Before + public void setUp() { + this.registration1 = ClientRegistration.withRegistrationId("registration-1") + .clientId("client-1") + .clientSecret("secret") + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .redirectUriTemplate("{baseUrl}/callback/client-1") + .scope("user") + .authorizationUri("https://provider.com/oauth2/authorize") + .tokenUri("https://provider.com/oauth2/token") + .userInfoUri("https://provider.com/oauth2/user") + .userNameAttributeName("id") + .clientName("client-1") + .build(); + this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(this.registration1); + this.authorizedClientService = new InMemoryOAuth2AuthorizedClientService(this.clientRegistrationRepository); + this.authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository(); + this.authenticationManager = mock(AuthenticationManager.class); + this.filter = spy(new OAuth2AuthorizationCodeGrantFilter( + this.clientRegistrationRepository, this.authorizedClientService, this.authenticationManager)); + this.filter.setAuthorizationRequestRepository(this.authorizationRequestRepository); + + SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); + securityContext.setAuthentication(new TestingAuthenticationToken(this.principalName1, "password")); + SecurityContextHolder.setContext(securityContext); + } + + @Test(expected = IllegalArgumentException.class) + public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { + new OAuth2AuthorizationCodeGrantFilter(null, this.authorizedClientService, this.authenticationManager); + } + + @Test(expected = IllegalArgumentException.class) + public void constructorWhenAuthorizedClientServiceIsNullThenThrowIllegalArgumentException() { + new OAuth2AuthorizationCodeGrantFilter(this.clientRegistrationRepository, null, this.authenticationManager); + } + + @Test(expected = IllegalArgumentException.class) + public void constructorWhenAuthenticationManagerIsNullThenThrowIllegalArgumentException() { + new OAuth2AuthorizationCodeGrantFilter(this.clientRegistrationRepository, this.authorizedClientService, null); + } + + @Test(expected = IllegalArgumentException.class) + public void setAuthorizationRequestRepositoryWhenAuthorizationRequestRepositoryIsNullThenThrowIllegalArgumentException() { + this.filter.setAuthorizationRequestRepository(null); + } + + @Test + public void doFilterWhenNotAuthorizationResponseThenNotProcessed() throws Exception { + String requestUri = "/path"; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + // NOTE: A valid Authorization Response contains either a 'code' or 'error' parameter. + + HttpServletResponse response = mock(HttpServletResponse.class); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + } + + @Test + public void doFilterWhenAuthorizationRequestNotFoundThenNotProcessed() throws Exception { + String requestUri = "/path"; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + request.addParameter(OAuth2ParameterNames.CODE, "code"); + request.addParameter(OAuth2ParameterNames.STATE, "state"); + + HttpServletResponse response = mock(HttpServletResponse.class); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + } + + @Test + public void doFilterWhenAuthorizationResponseValidThenAuthorizationRequestRemoved() throws Exception { + String requestUri = "/callback/client-1"; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + request.addParameter(OAuth2ParameterNames.CODE, "code"); + request.addParameter(OAuth2ParameterNames.STATE, "state"); + + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.setUpAuthorizationRequest(request, response, this.registration1); + this.setUpAuthenticationResult(this.registration1); + + this.filter.doFilter(request, response, filterChain); + + assertThat(this.authorizationRequestRepository.loadAuthorizationRequest(request)).isNull(); + } + + @Test + public void doFilterWhenAuthenticationFailsThenHandleOAuth2AuthenticationException() throws Exception { + String requestUri = "/callback/client-1"; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + request.addParameter(OAuth2ParameterNames.CODE, "code"); + request.addParameter(OAuth2ParameterNames.STATE, "state"); + + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.setUpAuthorizationRequest(request, response, this.registration1); + OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT); + when(this.authenticationManager.authenticate(any(Authentication.class))) + .thenThrow(new OAuth2AuthenticationException(error, error.toString())); + + this.filter.doFilter(request, response, filterChain); + + assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/callback/client-1?error=invalid_grant"); + } + + @Test + public void doFilterWhenAuthorizationResponseSuccessThenAuthorizedClientSaved() throws Exception { + String requestUri = "/callback/client-1"; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + request.addParameter(OAuth2ParameterNames.CODE, "code"); + request.addParameter(OAuth2ParameterNames.STATE, "state"); + + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.setUpAuthorizationRequest(request, response, this.registration1); + this.setUpAuthenticationResult(this.registration1); + + this.filter.doFilter(request, response, filterChain); + + OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient( + this.registration1.getRegistrationId(), this.principalName1); + assertThat(authorizedClient).isNotNull(); + assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration1); + assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principalName1); + assertThat(authorizedClient.getAccessToken()).isNotNull(); + } + + @Test + public void doFilterWhenAuthorizationResponseSuccessThenRedirected() throws Exception { + String requestUri = "/callback/client-1"; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + request.addParameter(OAuth2ParameterNames.CODE, "code"); + request.addParameter(OAuth2ParameterNames.STATE, "state"); + + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.setUpAuthorizationRequest(request, response, this.registration1); + this.setUpAuthenticationResult(this.registration1); + + this.filter.doFilter(request, response, filterChain); + + assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/callback/client-1"); + } + + private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response, + ClientRegistration registration) { + Map additionalParameters = new HashMap<>(); + additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId()); + OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class); + when(authorizationRequest.getAdditionalParameters()).thenReturn(additionalParameters); + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response); + } + + private void setUpAuthenticationResult(ClientRegistration registration) { + OAuth2AuthorizationCodeAuthenticationToken authentication = mock(OAuth2AuthorizationCodeAuthenticationToken.class); + when(authentication.getClientRegistration()).thenReturn(registration); + when(authentication.getAuthorizationExchange()).thenReturn(mock(OAuth2AuthorizationExchange.class)); + when(authentication.getAccessToken()).thenReturn(mock(OAuth2AccessToken.class)); + when(this.authenticationManager.authenticate(any(Authentication.class))).thenReturn(authentication); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java index 35f6635fd8..4ff9455c2d 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java @@ -19,7 +19,6 @@ import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; -import org.powermock.api.mockito.PowerMockito; import org.powermock.core.classloader.annotations.PowerMockIgnore; import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.modules.junit4.PowerMockRunner; @@ -54,9 +53,7 @@ import java.util.HashMap; import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.Assert.fail; import static org.mockito.Mockito.*; -import static org.powermock.api.mockito.PowerMockito.verifyPrivate; /** * Tests for {@link OAuth2LoginAuthenticationFilter}. @@ -265,25 +262,6 @@ public class OAuth2LoginAuthenticationFilterTests { verify(this.filter).attemptAuthentication(any(HttpServletRequest.class), any(HttpServletResponse.class)); } - @Test - public void attemptAuthenticationWhenAuthorizationRequestIsNullThenAuthorizationResponseNotCreated() throws Exception { - OAuth2LoginAuthenticationFilter filter = PowerMockito.spy(new OAuth2LoginAuthenticationFilter( - this.clientRegistrationRepository, this.authorizedClientService)); - - MockHttpServletRequest request = new MockHttpServletRequest(); - request.addParameter(OAuth2ParameterNames.CODE, "code"); - request.addParameter(OAuth2ParameterNames.STATE, "state"); - - MockHttpServletResponse response = new MockHttpServletResponse(); - - try { - filter.attemptAuthentication(request, response); - fail(); - } catch (OAuth2AuthenticationException ex) { - verifyPrivate(filter, never()).invoke("convert", any(HttpServletRequest.class)); - } - } - private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response, ClientRegistration registration) { OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);