diff --git a/config/spring-security-config.gradle b/config/spring-security-config.gradle index b6c6661299..0984a4cc82 100644 --- a/config/spring-security-config.gradle +++ b/config/spring-security-config.gradle @@ -66,6 +66,7 @@ dependencies { testRuntime 'cglib:cglib-nodep' testRuntime 'org.hsqldb:hsqldb' + testRuntime 'com.fasterxml.jackson.core:jackson-databind' } test { diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java index e8dfb0e214..65d004264a 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java @@ -65,6 +65,7 @@ import java.util.Map; * A security configurer for OAuth 2.0 / OpenID Connect 1.0 login. * * @author Joe Grandja + * @author Kazuki Shimizu * @since 5.0 */ public final class OAuth2LoginConfigurer> extends @@ -175,7 +176,6 @@ public final class OAuth2LoginConfigurer> exten private OAuth2UserService userService; private OAuth2UserService oidcUserService; private Map> customUserTypes = new HashMap<>(); - private GrantedAuthoritiesMapper userAuthoritiesMapper; private UserInfoEndpointConfig() { } @@ -201,7 +201,7 @@ public final class OAuth2LoginConfigurer> exten public UserInfoEndpointConfig userAuthoritiesMapper(GrantedAuthoritiesMapper userAuthoritiesMapper) { Assert.notNull(userAuthoritiesMapper, "userAuthoritiesMapper cannot be null"); - this.userAuthoritiesMapper = userAuthoritiesMapper; + OAuth2LoginConfigurer.this.getBuilder().setSharedObject(GrantedAuthoritiesMapper.class, userAuthoritiesMapper); return this; } @@ -244,9 +244,9 @@ public final class OAuth2LoginConfigurer> exten OAuth2LoginAuthenticationProvider oauth2LoginAuthenticationProvider = new OAuth2LoginAuthenticationProvider(accessTokenResponseClient, oauth2UserService); - if (this.userInfoEndpointConfig.userAuthoritiesMapper != null) { - oauth2LoginAuthenticationProvider.setAuthoritiesMapper( - this.userInfoEndpointConfig.userAuthoritiesMapper); + GrantedAuthoritiesMapper userAuthoritiesMapper = this.getGrantedAuthoritiesMapper(); + if (userAuthoritiesMapper != null) { + oauth2LoginAuthenticationProvider.setAuthoritiesMapper(userAuthoritiesMapper); } http.authenticationProvider(this.postProcess(oauth2LoginAuthenticationProvider)); @@ -261,9 +261,8 @@ public final class OAuth2LoginConfigurer> exten OidcAuthorizationCodeAuthenticationProvider oidcAuthorizationCodeAuthenticationProvider = new OidcAuthorizationCodeAuthenticationProvider(accessTokenResponseClient, oidcUserService); - if (this.userInfoEndpointConfig.userAuthoritiesMapper != null) { - oidcAuthorizationCodeAuthenticationProvider.setAuthoritiesMapper( - this.userInfoEndpointConfig.userAuthoritiesMapper); + if (userAuthoritiesMapper != null) { + oidcAuthorizationCodeAuthenticationProvider.setAuthoritiesMapper(userAuthoritiesMapper); } http.authenticationProvider(this.postProcess(oidcAuthorizationCodeAuthenticationProvider)); } else { @@ -340,6 +339,26 @@ public final class OAuth2LoginConfigurer> exten return (!authorizedClientServiceMap.isEmpty() ? authorizedClientServiceMap.values().iterator().next() : null); } + private GrantedAuthoritiesMapper getGrantedAuthoritiesMapper() { + GrantedAuthoritiesMapper grantedAuthoritiesMapper = + this.getBuilder().getSharedObject(GrantedAuthoritiesMapper.class); + if (grantedAuthoritiesMapper == null) { + grantedAuthoritiesMapper = this.getGrantedAuthoritiesMapperBean(); + if (grantedAuthoritiesMapper != null) { + this.getBuilder().setSharedObject(GrantedAuthoritiesMapper.class, grantedAuthoritiesMapper); + } + } + return grantedAuthoritiesMapper; + } + + private GrantedAuthoritiesMapper getGrantedAuthoritiesMapperBean() { + Map grantedAuthoritiesMapperMap = + BeanFactoryUtils.beansOfTypeIncludingAncestors( + this.getBuilder().getSharedObject(ApplicationContext.class), + GrantedAuthoritiesMapper.class); + return (!grantedAuthoritiesMapperMap.isEmpty() ? grantedAuthoritiesMapperMap.values().iterator().next() : null); + } + private void initDefaultLoginFilter(B http) { DefaultLoginPageGeneratingFilter loginPageGeneratingFilter = http.getSharedObject(DefaultLoginPageGeneratingFilter.class); if (loginPageGeneratingFilter == null || this.isCustomLoginPage()) { diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTest.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTest.java new file mode 100644 index 0000000000..024eb19ea2 --- /dev/null +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTest.java @@ -0,0 +1,446 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.config.annotation.web.configurers.oauth2.client; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.springframework.beans.PropertyAccessorFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.ConfigurableApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.mock.web.MockFilterChain; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.config.annotation.web.builders.HttpSecurity; +import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; +import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; +import org.springframework.security.config.oauth2.client.CommonOAuth2Provider; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.SimpleGrantedAuthority; +import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; +import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; +import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository; +import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest; +import org.springframework.security.oauth2.client.userinfo.OAuth2UserService; +import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository; +import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository; +import org.springframework.security.oauth2.client.web.OAuth2LoginAuthenticationFilter; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; +import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser; +import org.springframework.security.oauth2.core.oidc.user.OidcUser; +import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority; +import org.springframework.security.oauth2.core.user.DefaultOAuth2User; +import org.springframework.security.oauth2.core.user.OAuth2User; +import org.springframework.security.oauth2.core.user.OAuth2UserAuthority; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtDecoder; +import org.springframework.security.web.FilterChainProxy; +import org.springframework.security.web.context.HttpRequestResponseHolder; +import org.springframework.security.web.context.HttpSessionSecurityContextRepository; +import org.springframework.security.web.context.SecurityContextRepository; +import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; + +import javax.servlet.ServletException; +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link OAuth2LoginConfigurer}. + * + * @author Kazuki Shimizu + * @since 5.0.1 + */ +public class OAuth2LoginConfigurerTest { + + private static final ClientRegistration CLIENT_REGISTRATION = CommonOAuth2Provider.GOOGLE + .getBuilder("google").clientId("clientId").clientSecret("clientSecret") + .build(); + + private ConfigurableApplicationContext context; + + @Autowired + private FilterChainProxy springSecurityFilterChain; + @Autowired + private AuthorizationRequestRepository authorizationRequestRepository; + @Autowired + SecurityContextRepository securityContextRepository; + + private MockHttpServletRequest request; + private MockHttpServletResponse response; + private MockFilterChain filterChain; + + @Before + public void setup() { + this.request = new MockHttpServletRequest(); + this.response = new MockHttpServletResponse(); + this.filterChain = new MockFilterChain(); + + this.request.setMethod("GET"); + this.request.setServletPath("/login/oauth2/code/google"); + } + + @After + public void cleanup() { + if (this.context != null) { + this.context.close(); + } + } + + @Test + public void oauth2Login() throws IOException, ServletException { + + // setup application context + loadConfig(OAuth2LoginConfig.class); + + // setup authorization request + OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest(); + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, + this.request, this.response); + + // setup authentication parameters + this.request.setParameter("code", "code123"); + this.request.setParameter("state", authorizationRequest.getState()); + + // perform test + this.springSecurityFilterChain.doFilter(this.request, this.response, + this.filterChain); + + // assertions + Authentication authentication = this.securityContextRepository + .loadContext(new HttpRequestResponseHolder(this.request, this.response)) + .getAuthentication(); + assertThat(authentication.getAuthorities()).hasSize(1); + assertThat(authentication.getAuthorities()).first() + .isInstanceOf(OAuth2UserAuthority.class).hasToString("ROLE_USER"); + } + + @Test + public void oauth2LoginCustomizeUsingConfigurerMethod() + throws IOException, ServletException { + + // setup application context + loadConfig(OAuth2LoginConfigCustomizeUsingConfigurerMethod.class); + + // setup authorization request + OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest(); + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, + this.request, this.response); + + // setup authentication parameters + this.request.setParameter("code", "code123"); + this.request.setParameter("state", authorizationRequest.getState()); + + // perform test + this.springSecurityFilterChain.doFilter(this.request, this.response, + this.filterChain); + + // assertions + Authentication authentication = this.securityContextRepository + .loadContext(new HttpRequestResponseHolder(this.request, this.response)) + .getAuthentication(); + assertThat(authentication.getAuthorities()).hasSize(2); + assertThat(authentication.getAuthorities()).first().hasToString("ROLE_USER"); + assertThat(authentication.getAuthorities()).last() + .hasToString("ROLE_OAUTH2_USER"); + } + + @Test + public void oauth2LoginCustomizeUsingAutoDetection() + throws IOException, ServletException { + + // setup application context + loadConfig(OAuth2LoginConfigCustomizeUsingAutoDetection.class); + + // setup authorization request + OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest(); + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, + this.request, this.response); + + // setup authentication parameters + this.request.setParameter("code", "code123"); + this.request.setParameter("state", authorizationRequest.getState()); + + // perform test + this.springSecurityFilterChain.doFilter(this.request, this.response, + this.filterChain); + + // assertions + Authentication authentication = this.securityContextRepository + .loadContext(new HttpRequestResponseHolder(this.request, this.response)) + .getAuthentication(); + assertThat(authentication.getAuthorities()).hasSize(2); + assertThat(authentication.getAuthorities()).first().hasToString("ROLE_USER"); + assertThat(authentication.getAuthorities()).last() + .hasToString("ROLE_OAUTH2_USER"); + } + + @Test + public void oidcLogin() throws IOException, ServletException { + + // setup application context + loadConfig(OAuth2LoginConfig.class); + registerJwtDecoder(); + + // setup authorization request + OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest( + "openid"); + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, + this.request, this.response); + + // setup authentication parameters + this.request.setParameter("code", "code123"); + this.request.setParameter("state", authorizationRequest.getState()); + + // perform test + this.springSecurityFilterChain.doFilter(this.request, this.response, + this.filterChain); + + // assertions + Authentication authentication = this.securityContextRepository + .loadContext(new HttpRequestResponseHolder(this.request, this.response)) + .getAuthentication(); + assertThat(authentication.getAuthorities()).hasSize(1); + assertThat(authentication.getAuthorities()).first() + .isInstanceOf(OidcUserAuthority.class).hasToString("ROLE_USER"); + } + + @Test + public void oidcLoginCustomizeUsingConfigurerMethod() + throws IOException, ServletException { + + // setup application context + loadConfig(OAuth2LoginConfigCustomizeUsingConfigurerMethod.class); + registerJwtDecoder(); + + // setup authorization request + OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest( + "openid"); + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, + this.request, this.response); + + // setup authentication parameters + this.request.setParameter("code", "code123"); + this.request.setParameter("state", authorizationRequest.getState()); + + // perform test + this.springSecurityFilterChain.doFilter(this.request, this.response, + this.filterChain); + + // assertions + Authentication authentication = this.securityContextRepository + .loadContext(new HttpRequestResponseHolder(this.request, this.response)) + .getAuthentication(); + assertThat(authentication.getAuthorities()).hasSize(2); + assertThat(authentication.getAuthorities()).first().hasToString("ROLE_USER"); + assertThat(authentication.getAuthorities()).last().hasToString("ROLE_OIDC_USER"); + } + + @Test + public void oidcLoginCustomizeUsingAutoDetection() + throws IOException, ServletException { + + // setup application context + loadConfig(OAuth2LoginConfigCustomizeUsingAutoDetection.class); + registerJwtDecoder(); + + // setup authorization request + OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest( + "openid"); + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, + this.request, this.response); + + // setup authentication parameters + this.request.setParameter("code", "code123"); + this.request.setParameter("state", authorizationRequest.getState()); + + // perform test + this.springSecurityFilterChain.doFilter(this.request, this.response, + this.filterChain); + + // assertions + Authentication authentication = this.securityContextRepository + .loadContext(new HttpRequestResponseHolder(this.request, this.response)) + .getAuthentication(); + assertThat(authentication.getAuthorities()).hasSize(2); + assertThat(authentication.getAuthorities()).first().hasToString("ROLE_USER"); + assertThat(authentication.getAuthorities()).last().hasToString("ROLE_OIDC_USER"); + } + + private void loadConfig(Class... configs) { + AnnotationConfigWebApplicationContext applicationContext = new AnnotationConfigWebApplicationContext(); + applicationContext.register(configs); + applicationContext.refresh(); + applicationContext.getAutowireCapableBeanFactory().autowireBean(this); + this.context = applicationContext; + } + + private void registerJwtDecoder() { + JwtDecoder decoder = token -> { + Map claims = new HashMap<>(); + claims.put(IdTokenClaimNames.SUB, "sub123"); + claims.put(IdTokenClaimNames.ISS, "http://localhost/iss"); + claims.put(IdTokenClaimNames.AUD, Arrays.asList("clientId", "a", "u", "d")); + claims.put(IdTokenClaimNames.AZP, "clientId"); + return new Jwt("token123", Instant.now(), Instant.now().plusSeconds(3600), + Collections.singletonMap("header1", "value1"), claims); + }; + this.springSecurityFilterChain.getFilters("/login/oauth2/code/google").stream() + .filter(OAuth2LoginAuthenticationFilter.class::isInstance).findFirst() + .ifPresent(filter -> PropertyAccessorFactory.forDirectFieldAccess(filter) + .setPropertyValue( + "authenticationManager.providers[2].jwtDecoders['google']", + decoder)); + } + + private OAuth2AuthorizationRequest createOAuth2AuthorizationRequest( + String... scopes) { + return OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri( + CLIENT_REGISTRATION.getProviderDetails().getAuthorizationUri()) + .clientId(CLIENT_REGISTRATION.getClientId()).state("state123") + .redirectUri("http://localhost") + .additionalParameters( + Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, + CLIENT_REGISTRATION.getRegistrationId())) + .scope(scopes).build(); + } + + @EnableWebSecurity + static class OAuth2LoginConfig extends CommonWebSecurityConfigurerAdapter { + @Override + protected void configure(HttpSecurity http) throws Exception { + http.oauth2Login().clientRegistrationRepository( + new InMemoryClientRegistrationRepository(CLIENT_REGISTRATION)); + super.configure(http); + } + } + + @EnableWebSecurity + static class OAuth2LoginConfigCustomizeUsingConfigurerMethod + extends CommonWebSecurityConfigurerAdapter { + @Override + protected void configure(HttpSecurity http) throws Exception { + http.oauth2Login() + .clientRegistrationRepository( + new InMemoryClientRegistrationRepository(CLIENT_REGISTRATION)) + .userInfoEndpoint() + .userAuthoritiesMapper(createGrantedAuthoritiesMapper()); + super.configure(http); + } + } + + @EnableWebSecurity + static class OAuth2LoginConfigCustomizeUsingAutoDetection + extends CommonWebSecurityConfigurerAdapter { + @Override + protected void configure(HttpSecurity http) throws Exception { + http.oauth2Login(); + super.configure(http); + } + + @Bean + ClientRegistrationRepository clientRegistrationRepository() { + return new InMemoryClientRegistrationRepository(CLIENT_REGISTRATION); + } + + @Bean + GrantedAuthoritiesMapper grantedAuthoritiesMapper() { + return createGrantedAuthoritiesMapper(); + } + } + + private static abstract class CommonWebSecurityConfigurerAdapter + extends WebSecurityConfigurerAdapter { + @Override + protected void configure(HttpSecurity http) throws Exception { + http.securityContext().securityContextRepository(securityContextRepository()) + .and().oauth2Login().tokenEndpoint() + .accessTokenResponseClient(createOauth2AccessTokenResponseClient()) + .and().userInfoEndpoint().userService(createOauth2UserService()) + .oidcUserService(createOidcUserService()); + } + + @Bean + SecurityContextRepository securityContextRepository() { + return new HttpSessionSecurityContextRepository(); + } + + @Bean + HttpSessionOAuth2AuthorizationRequestRepository oauth2AuthorizationRequestRepository() { + return new HttpSessionOAuth2AuthorizationRequestRepository(); + } + } + + private static OAuth2AccessTokenResponseClient createOauth2AccessTokenResponseClient() { + return request -> { + Map additionalParameters = new HashMap<>(); + if (request.getAuthorizationExchange().getAuthorizationRequest().getScopes() + .contains("openid")) { + additionalParameters.put(OidcParameterNames.ID_TOKEN, "token123"); + } + return OAuth2AccessTokenResponse.withToken("accessToken123") + .tokenType(OAuth2AccessToken.TokenType.BEARER) + .additionalParameters(additionalParameters).build(); + }; + } + + private static OAuth2UserService createOauth2UserService() { + Map userAttributes = Collections.singletonMap("name", "spring"); + return request -> new DefaultOAuth2User( + Collections.singleton(new OAuth2UserAuthority(userAttributes)), + userAttributes, "name"); + } + + private static OAuth2UserService createOidcUserService() { + OidcIdToken idToken = new OidcIdToken("token123", Instant.now(), + Instant.now().plusSeconds(3600), + Collections.singletonMap(IdTokenClaimNames.SUB, "sub123")); + return request -> new DefaultOidcUser( + Collections.singleton(new OidcUserAuthority(idToken)), idToken); + } + + private static GrantedAuthoritiesMapper createGrantedAuthoritiesMapper() { + return authorities -> { + boolean isOidc = OidcUserAuthority.class + .isInstance(authorities.iterator().next()); + List mappedAuthorities = new ArrayList<>(authorities); + mappedAuthorities.add(new SimpleGrantedAuthority( + isOidc ? "ROLE_OIDC_USER" : "ROLE_OAUTH2_USER")); + return mappedAuthorities; + }; + } + +}