From 5c14e48b18ef98dcf3fc51b0869be6b9768c97d6 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Thu, 5 Oct 2017 15:00:35 -0400 Subject: [PATCH] Add OAuth2UserAuthenticationProvider Moved logic from AuthorizationCodeAuthenticationProvider to OAuth2UserAuthenticationProvider (new) related to loading user attributes via OAuth2UserService. This re-factor is part of the work required for Issue gh-4513 --- ...ionCodeAuthenticationFilterConfigurer.java | 18 +++- ...thorizationCodeAuthenticationProvider.java | 56 +---------- .../OAuth2UserAuthenticationProvider.java | 99 +++++++++++++++++++ .../OAuth2UserAuthenticationToken.java | 8 +- ...ionCodeAuthenticationProcessingFilter.java | 86 +++++++++++++++- .../OidcUserAuthenticationToken.java | 11 +++ ...deAuthenticationProcessingFilterTests.java | 46 +++++---- .../META-INF/oauth2-clients-defaults.yml | 2 + 8 files changed, 244 insertions(+), 82 deletions(-) create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2UserAuthenticationProvider.java diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/AuthorizationCodeAuthenticationFilterConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/AuthorizationCodeAuthenticationFilterConfigurer.java index ce0e4f670a..bd58eccfd3 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/AuthorizationCodeAuthenticationFilterConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/AuthorizationCodeAuthenticationFilterConfigurer.java @@ -23,6 +23,7 @@ import org.springframework.security.oauth2.client.authentication.AuthorizationCo import org.springframework.security.oauth2.client.authentication.AuthorizationCodeAuthenticator; import org.springframework.security.oauth2.client.authentication.AuthorizationGrantAuthenticator; import org.springframework.security.oauth2.client.authentication.DelegatingAuthorizationGrantAuthenticator; +import org.springframework.security.oauth2.client.authentication.OAuth2UserAuthenticationProvider; import org.springframework.security.oauth2.client.authentication.jwt.JwtDecoderRegistry; import org.springframework.security.oauth2.client.authentication.jwt.nimbus.NimbusJwtDecoderRegistry; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; @@ -128,13 +129,20 @@ final class AuthorizationCodeAuthenticationFilterConfigurerauthorization code credential with the authorization server's Token Endpoint * and if valid, exchanging it for an access token credential and optionally an * id token credential (for OpenID Connect Authorization Code Flow). - * Additionally, it will also obtain the end-user's (resource owner) attributes from the UserInfo Endpoint - * (using the access token) and create a Principal in the form of an {@link OAuth2User} - * associating it with the returned {@link OAuth2UserAuthenticationToken}. * *

* The {@link AuthorizationCodeAuthenticationProvider} uses an {@link AuthorizationGrantAuthenticator} * to authenticate the {@link AuthorizationCodeAuthenticationToken#getAuthorizationCode()} and ultimately * return an "Authorized Client" as an {@link OAuth2ClientAuthenticationToken}. * - *

- * It will then call {@link OAuth2UserService#loadUser(OAuth2ClientAuthenticationToken)} - * to obtain the end-user's (resource owner) attributes in the form of an {@link OAuth2User}. - * - *

- * Finally, it will create an {@link OAuth2UserAuthenticationToken}, associating the {@link OAuth2User} - * and {@link OAuth2ClientAuthenticationToken} and return it to the {@link AuthenticationManager}, - * at which point the {@link OAuth2UserAuthenticationToken} is considered "authenticated". - * * @author Joe Grandja * @since 5.0 * @see AuthorizationCodeAuthenticationToken * @see OAuth2ClientAuthenticationToken * @see OidcClientAuthenticationToken - * @see OAuth2UserAuthenticationToken - * @see OidcUserAuthenticationToken * @see AuthorizationGrantAuthenticator - * @see OAuth2UserService - * @see OAuth2User - * @see OidcUser + * @see SecurityTokenRepository * @see Section 4.1 Authorization Code Grant Flow * @see Section 3.1 OpenID Connect Authorization Code Flow * @see Section 4.1.3 Access Token Request @@ -75,20 +50,15 @@ import java.util.Collection; public class AuthorizationCodeAuthenticationProvider implements AuthenticationProvider { private final AuthorizationGrantAuthenticator authorizationCodeAuthenticator; private final SecurityTokenRepository accessTokenRepository; - private final OAuth2UserService userService; - private GrantedAuthoritiesMapper authoritiesMapper = (authorities -> authorities); public AuthorizationCodeAuthenticationProvider( AuthorizationGrantAuthenticator authorizationCodeAuthenticator, - SecurityTokenRepository accessTokenRepository, - OAuth2UserService userService) { + SecurityTokenRepository accessTokenRepository) { Assert.notNull(authorizationCodeAuthenticator, "authorizationCodeAuthenticator cannot be null"); Assert.notNull(accessTokenRepository, "accessTokenRepository cannot be null"); - Assert.notNull(userService, "userService cannot be null"); this.authorizationCodeAuthenticator = authorizationCodeAuthenticator; this.accessTokenRepository = accessTokenRepository; - this.userService = userService; } @Override @@ -103,27 +73,7 @@ public class AuthorizationCodeAuthenticationProvider implements AuthenticationPr oauth2ClientAuthentication.getAccessToken(), oauth2ClientAuthentication.getClientRegistration()); - OAuth2User oauth2User = this.userService.loadUser(oauth2ClientAuthentication); - - Collection mappedAuthorities = - this.authoritiesMapper.mapAuthorities(oauth2User.getAuthorities()); - - OAuth2UserAuthenticationToken oauth2UserAuthentication; - if (OidcUser.class.isAssignableFrom(oauth2User.getClass())) { - oauth2UserAuthentication = new OidcUserAuthenticationToken( - (OidcUser)oauth2User, mappedAuthorities, (OidcClientAuthenticationToken)oauth2ClientAuthentication); - } else { - oauth2UserAuthentication = new OAuth2UserAuthenticationToken( - oauth2User, mappedAuthorities, oauth2ClientAuthentication); - } - oauth2UserAuthentication.setDetails(oauth2ClientAuthentication.getDetails()); - - return oauth2UserAuthentication; - } - - public final void setAuthoritiesMapper(GrantedAuthoritiesMapper authoritiesMapper) { - Assert.notNull(authoritiesMapper, "authoritiesMapper cannot be null"); - this.authoritiesMapper = authoritiesMapper; + return oauth2ClientAuthentication; } @Override diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2UserAuthenticationProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2UserAuthenticationProvider.java new file mode 100644 index 0000000000..fa1c4627e7 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2UserAuthenticationProvider.java @@ -0,0 +1,99 @@ +/* + * Copyright 2012-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.authentication; + +import org.springframework.security.authentication.AuthenticationProvider; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.AuthenticationException; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; +import org.springframework.security.oauth2.client.user.DefaultOAuth2UserService; +import org.springframework.security.oauth2.client.user.OAuth2UserService; +import org.springframework.security.oauth2.core.user.OAuth2User; +import org.springframework.security.oauth2.oidc.client.authentication.OidcClientAuthenticationToken; +import org.springframework.security.oauth2.oidc.client.authentication.OidcUserAuthenticationToken; +import org.springframework.security.oauth2.oidc.client.user.OidcUserService; +import org.springframework.security.oauth2.oidc.core.user.OidcUser; +import org.springframework.util.Assert; + +import java.util.Collection; + +/** + * An implementation of an {@link AuthenticationProvider} that is responsible + * for obtaining the user attributes of the End-User (resource owner) + * from the UserInfo Endpoint and creating a Principal + * in the form of an {@link OAuth2User}. + * + *

+ * The {@link OAuth2UserAuthenticationProvider} uses an {@link OAuth2UserService} + * for loading the {@link OAuth2User} and then associating it + * to the returned {@link OAuth2UserAuthenticationToken}. + * + * @author Joe Grandja + * @since 5.0 + * @see OAuth2UserAuthenticationToken + * @see OidcUserAuthenticationToken + * @see OAuth2ClientAuthenticationToken + * @see OidcClientAuthenticationToken + * @see OAuth2UserService + * @see DefaultOAuth2UserService + * @see OidcUserService + * @see OAuth2User + * @see OidcUser + */ +public class OAuth2UserAuthenticationProvider implements AuthenticationProvider { + private final OAuth2UserService userService; + private GrantedAuthoritiesMapper authoritiesMapper = (authorities -> authorities); + + public OAuth2UserAuthenticationProvider(OAuth2UserService userService) { + Assert.notNull(userService, "userService cannot be null"); + this.userService = userService; + } + + @Override + public Authentication authenticate(Authentication authentication) throws AuthenticationException { + OAuth2UserAuthenticationToken oauth2UserAuthentication = (OAuth2UserAuthenticationToken) authentication; + + OAuth2ClientAuthenticationToken oauth2ClientAuthentication = oauth2UserAuthentication.getClientAuthentication(); + + OAuth2User oauth2User = this.userService.loadUser(oauth2ClientAuthentication); + + Collection mappedAuthorities = + this.authoritiesMapper.mapAuthorities(oauth2User.getAuthorities()); + + OAuth2UserAuthenticationToken authenticationResult; + if (OidcUser.class.isAssignableFrom(oauth2User.getClass())) { + authenticationResult = new OidcUserAuthenticationToken( + (OidcUser)oauth2User, mappedAuthorities, (OidcClientAuthenticationToken)oauth2ClientAuthentication); + } else { + authenticationResult = new OAuth2UserAuthenticationToken( + oauth2User, mappedAuthorities, oauth2ClientAuthentication); + } + authenticationResult.setDetails(oauth2ClientAuthentication.getDetails()); + + return authenticationResult; + } + + public final void setAuthoritiesMapper(GrantedAuthoritiesMapper authoritiesMapper) { + Assert.notNull(authoritiesMapper, "authoritiesMapper cannot be null"); + this.authoritiesMapper = authoritiesMapper; + } + + @Override + public boolean supports(Class authentication) { + return OAuth2UserAuthenticationToken.class.isAssignableFrom(authentication); + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2UserAuthenticationToken.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2UserAuthenticationToken.java index 92ae66009c..0671d99fc0 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2UserAuthenticationToken.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2UserAuthenticationToken.java @@ -19,6 +19,7 @@ import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.SpringSecurityCoreVersion; +import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.util.Assert; @@ -42,14 +43,17 @@ public class OAuth2UserAuthenticationToken extends AbstractAuthenticationToken { private final OAuth2User principal; private final OAuth2ClientAuthenticationToken clientAuthentication; + public OAuth2UserAuthenticationToken(OAuth2ClientAuthenticationToken clientAuthentication) { + this(null, AuthorityUtils.NO_AUTHORITIES, clientAuthentication); + } + public OAuth2UserAuthenticationToken(OAuth2User principal, Collection authorities, OAuth2ClientAuthenticationToken clientAuthentication) { super(authorities); - Assert.notNull(principal, "principal cannot be null"); Assert.notNull(clientAuthentication, "clientAuthentication cannot be null"); this.principal = principal; this.clientAuthentication = clientAuthentication; - this.setAuthenticated(true); + this.setAuthenticated(principal != null); } @Override diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthorizationCodeAuthenticationProcessingFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthorizationCodeAuthenticationProcessingFilter.java index 78d8de0478..5aaf439d6c 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthorizationCodeAuthenticationProcessingFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthorizationCodeAuthenticationProcessingFilter.java @@ -18,10 +18,14 @@ package org.springframework.security.oauth2.client.web; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; +import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.client.authentication.AuthorizationCodeAuthenticationProvider; import org.springframework.security.oauth2.client.authentication.AuthorizationCodeAuthenticationToken; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationException; +import org.springframework.security.oauth2.client.authentication.OAuth2ClientAuthenticationToken; +import org.springframework.security.oauth2.client.authentication.OAuth2UserAuthenticationToken; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationIdentifierStrategy; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.web.converter.AuthorizationCodeAuthorizationResponseAttributesConverter; import org.springframework.security.oauth2.client.web.converter.ErrorResponseAttributesConverter; @@ -30,6 +34,10 @@ import org.springframework.security.oauth2.core.endpoint.AuthorizationCodeAuthor import org.springframework.security.oauth2.core.endpoint.AuthorizationRequestAttributes; import org.springframework.security.oauth2.core.endpoint.ErrorResponseAttributes; import org.springframework.security.oauth2.core.endpoint.OAuth2Parameter; +import org.springframework.security.oauth2.core.user.OAuth2User; +import org.springframework.security.oauth2.oidc.client.authentication.OidcClientAuthenticationToken; +import org.springframework.security.oauth2.oidc.client.authentication.OidcUserAuthenticationToken; +import org.springframework.security.oauth2.oidc.core.user.OidcUser; import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; @@ -84,6 +92,7 @@ public class AuthorizationCodeAuthenticationProcessingFilter extends AbstractAut private ClientRegistrationRepository clientRegistrationRepository; private RequestMatcher authorizationResponseMatcher = new AuthorizationResponseMatcher(); private AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionAuthorizationRequestRepository(); + private final ClientRegistrationIdentifierStrategy providerIdentifierStrategy = new ProviderIdentifierStrategy(); public AuthorizationCodeAuthenticationProcessingFilter() { super(new AuthorizationResponseMatcher()); @@ -119,14 +128,27 @@ public class AuthorizationCodeAuthenticationProcessingFilter extends AbstractAut AuthorizationCodeAuthorizationResponseAttributes authorizationCodeResponseAttributes = this.authorizationCodeResponseConverter.apply(request); - AuthorizationCodeAuthenticationToken authRequest = new AuthorizationCodeAuthenticationToken( + AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = new AuthorizationCodeAuthenticationToken( authorizationCodeResponseAttributes.getCode(), clientRegistration, matchingAuthorizationRequest); + authorizationCodeAuthentication.setDetails(this.authenticationDetailsSource.buildDetails(request)); - authRequest.setDetails(this.authenticationDetailsSource.buildDetails(request)); + OAuth2ClientAuthenticationToken oauth2ClientAuthentication = + (OAuth2ClientAuthenticationToken)this.getAuthenticationManager().authenticate(authorizationCodeAuthentication); - Authentication authenticated = this.getAuthenticationManager().authenticate(authRequest); + OAuth2UserAuthenticationToken oauth2UserAuthentication; + if (this.authenticated() && this.authenticatedSameProviderAs(oauth2ClientAuthentication)) { + // Create a new user authentication (using same principal) + // but with a different client authentication association + oauth2UserAuthentication = (OAuth2UserAuthenticationToken)SecurityContextHolder.getContext().getAuthentication(); + oauth2UserAuthentication = this.createUserAuthentication(oauth2UserAuthentication, oauth2ClientAuthentication); + } else { + // Authenticate the user... the user needs to be authenticated + // before we can associate the client authentication to the user + oauth2UserAuthentication = (OAuth2UserAuthenticationToken)this.getAuthenticationManager().authenticate( + this.createUserAuthentication(oauth2ClientAuthentication)); + } - return authenticated; + return oauth2UserAuthentication; } public RequestMatcher getAuthorizationResponseMatcher() { @@ -182,6 +204,50 @@ public class AuthorizationCodeAuthenticationProcessingFilter extends AbstractAut } } + private boolean authenticated() { + Authentication currentAuthentication = SecurityContextHolder.getContext().getAuthentication(); + return currentAuthentication != null && + currentAuthentication instanceof OAuth2UserAuthenticationToken && + currentAuthentication.isAuthenticated(); + } + + private boolean authenticatedSameProviderAs(OAuth2ClientAuthenticationToken oauth2ClientAuthentication) { + OAuth2UserAuthenticationToken userAuthentication = + (OAuth2UserAuthenticationToken)SecurityContextHolder.getContext().getAuthentication(); + + String userProviderId = this.providerIdentifierStrategy.getIdentifier( + userAuthentication.getClientAuthentication().getClientRegistration()); + String clientProviderId = this.providerIdentifierStrategy.getIdentifier( + oauth2ClientAuthentication.getClientRegistration()); + + return userProviderId.equals(clientProviderId); + } + + private OAuth2UserAuthenticationToken createUserAuthentication(OAuth2ClientAuthenticationToken clientAuthentication) { + if (OidcClientAuthenticationToken.class.isAssignableFrom(clientAuthentication.getClass())) { + return new OidcUserAuthenticationToken((OidcClientAuthenticationToken)clientAuthentication); + } else { + return new OAuth2UserAuthenticationToken(clientAuthentication); + } + } + + private OAuth2UserAuthenticationToken createUserAuthentication( + OAuth2UserAuthenticationToken currentUserAuthentication, + OAuth2ClientAuthenticationToken newClientAuthentication) { + + if (OidcUserAuthenticationToken.class.isAssignableFrom(currentUserAuthentication.getClass())) { + return new OidcUserAuthenticationToken( + (OidcUser) currentUserAuthentication.getPrincipal(), + currentUserAuthentication.getAuthorities(), + newClientAuthentication); + } else { + return new OAuth2UserAuthenticationToken( + (OAuth2User)currentUserAuthentication.getPrincipal(), + currentUserAuthentication.getAuthorities(), + newClientAuthentication); + } + } + private static class AuthorizationResponseMatcher implements RequestMatcher { @Override @@ -199,4 +265,16 @@ public class AuthorizationCodeAuthenticationProcessingFilter extends AbstractAut StringUtils.hasText(request.getParameter(OAuth2Parameter.STATE)); } } + + private static class ProviderIdentifierStrategy implements ClientRegistrationIdentifierStrategy { + + @Override + public String getIdentifier(ClientRegistration clientRegistration) { + StringBuilder builder = new StringBuilder(); + builder.append("[").append(clientRegistration.getProviderDetails().getAuthorizationUri()).append("]"); + builder.append("[").append(clientRegistration.getProviderDetails().getTokenUri()).append("]"); + builder.append("[").append(clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri()).append("]"); + return builder.toString(); + } + } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/oidc/client/authentication/OidcUserAuthenticationToken.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/oidc/client/authentication/OidcUserAuthenticationToken.java index df98aca977..78c2979b47 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/oidc/client/authentication/OidcUserAuthenticationToken.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/oidc/client/authentication/OidcUserAuthenticationToken.java @@ -17,6 +17,8 @@ package org.springframework.security.oauth2.oidc.client.authentication; import org.springframework.security.core.Authentication; import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.oauth2.client.authentication.OAuth2ClientAuthenticationToken; import org.springframework.security.oauth2.client.authentication.OAuth2UserAuthenticationToken; import org.springframework.security.oauth2.oidc.core.user.OidcUser; @@ -38,8 +40,17 @@ import java.util.Collection; */ public class OidcUserAuthenticationToken extends OAuth2UserAuthenticationToken { + public OidcUserAuthenticationToken(OidcClientAuthenticationToken clientAuthentication) { + this(null, AuthorityUtils.NO_AUTHORITIES, clientAuthentication); + } + public OidcUserAuthenticationToken(OidcUser principal, Collection authorities, OidcClientAuthenticationToken clientAuthentication) { + this(principal, authorities, (OAuth2ClientAuthenticationToken)clientAuthentication); + } + + public OidcUserAuthenticationToken(OidcUser principal, Collection authorities, + OAuth2ClientAuthenticationToken clientAuthentication) { super(principal, authorities, clientAuthentication); } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/AuthorizationCodeAuthenticationProcessingFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/AuthorizationCodeAuthenticationProcessingFilterTests.java index 42b2309326..d36590bd95 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/AuthorizationCodeAuthenticationProcessingFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/AuthorizationCodeAuthenticationProcessingFilterTests.java @@ -23,15 +23,20 @@ import org.mockito.Mockito; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.AuthenticationManager; -import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationException; +import org.springframework.security.oauth2.client.authentication.OAuth2ClientAuthenticationToken; +import org.springframework.security.oauth2.client.authentication.OAuth2UserAuthenticationToken; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.core.AccessToken; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.endpoint.AuthorizationRequestAttributes; import org.springframework.security.oauth2.core.endpoint.OAuth2Parameter; +import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.AuthenticationSuccessHandler; @@ -41,6 +46,8 @@ import javax.servlet.http.HttpServletResponse; import java.util.HashMap; import java.util.Map; +import static org.mockito.Mockito.mock; + /** * Tests {@link AuthorizationCodeAuthenticationProcessingFilter}. * @@ -58,7 +65,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests { MockHttpServletRequest request = new MockHttpServletRequest("GET", requestURI); request.setServletPath(requestURI); MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = Mockito.mock(FilterChain.class); + FilterChain filterChain = mock(FilterChain.class); filter.doFilter(request, response, filterChain); @@ -71,7 +78,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests { ClientRegistration clientRegistration = TestUtil.githubClientRegistration(); AuthorizationCodeAuthenticationProcessingFilter filter = Mockito.spy(setupFilter(clientRegistration)); - AuthenticationFailureHandler failureHandler = Mockito.mock(AuthenticationFailureHandler.class); + AuthenticationFailureHandler failureHandler = mock(AuthenticationFailureHandler.class); filter.setAuthenticationFailureHandler(failureHandler); MockHttpServletRequest request = this.setupRequest(clientRegistration); @@ -79,7 +86,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests { request.addParameter(OAuth2Parameter.ERROR, errorCode); request.addParameter(OAuth2Parameter.STATE, "some state"); MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = Mockito.mock(FilterChain.class); + FilterChain filterChain = mock(FilterChain.class); filter.doFilter(request, response, filterChain); @@ -90,14 +97,17 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests { @Test public void doFilterWhenAuthorizationCodeSuccessResponseThenAuthenticationSuccessHandlerIsCalled() throws Exception { - TestingAuthenticationToken authentication = new TestingAuthenticationToken("joe", "password", "user", "admin"); - AuthenticationManager authenticationManager = Mockito.mock(AuthenticationManager.class); - Mockito.when(authenticationManager.authenticate(Matchers.any(Authentication.class))).thenReturn(authentication); - ClientRegistration clientRegistration = TestUtil.githubClientRegistration(); + OAuth2ClientAuthenticationToken clientAuthentication = new OAuth2ClientAuthenticationToken( + clientRegistration, mock(AccessToken.class)); + OAuth2UserAuthenticationToken userAuthentication = new OAuth2UserAuthenticationToken( + mock(OAuth2User.class), AuthorityUtils.createAuthorityList("ROLE_USER"), clientAuthentication); + SecurityContextHolder.getContext().setAuthentication(userAuthentication); + AuthenticationManager authenticationManager = mock(AuthenticationManager.class); + Mockito.when(authenticationManager.authenticate(Matchers.any(Authentication.class))).thenReturn(clientAuthentication); AuthorizationCodeAuthenticationProcessingFilter filter = Mockito.spy(setupFilter(authenticationManager, clientRegistration)); - AuthenticationSuccessHandler successHandler = Mockito.mock(AuthenticationSuccessHandler.class); + AuthenticationSuccessHandler successHandler = mock(AuthenticationSuccessHandler.class); filter.setAuthenticationSuccessHandler(successHandler); AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionAuthorizationRequestRepository(); filter.setAuthorizationRequestRepository(authorizationRequestRepository); @@ -109,7 +119,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests { request.addParameter(OAuth2Parameter.STATE, state); MockHttpServletResponse response = new MockHttpServletResponse(); setupAuthorizationRequest(authorizationRequestRepository, request, response, clientRegistration, state); - FilterChain filterChain = Mockito.mock(FilterChain.class); + FilterChain filterChain = mock(FilterChain.class); filter.doFilter(request, response, filterChain); @@ -118,7 +128,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests { ArgumentCaptor authenticationArgCaptor = ArgumentCaptor.forClass(Authentication.class); Mockito.verify(successHandler).onAuthenticationSuccess(Matchers.any(HttpServletRequest.class), Matchers.any(HttpServletResponse.class), authenticationArgCaptor.capture()); - Assertions.assertThat(authenticationArgCaptor.getValue()).isEqualTo(authentication); + Assertions.assertThat(authenticationArgCaptor.getValue()).isEqualTo(userAuthentication); } @Test @@ -126,7 +136,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests { ClientRegistration clientRegistration = TestUtil.githubClientRegistration(); AuthorizationCodeAuthenticationProcessingFilter filter = Mockito.spy(setupFilter(clientRegistration)); - AuthenticationFailureHandler failureHandler = Mockito.mock(AuthenticationFailureHandler.class); + AuthenticationFailureHandler failureHandler = mock(AuthenticationFailureHandler.class); filter.setAuthenticationFailureHandler(failureHandler); MockHttpServletRequest request = this.setupRequest(clientRegistration); @@ -135,7 +145,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests { request.addParameter(OAuth2Parameter.CODE, authCode); request.addParameter(OAuth2Parameter.STATE, state); MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = Mockito.mock(FilterChain.class); + FilterChain filterChain = mock(FilterChain.class); filter.doFilter(request, response, filterChain); @@ -147,7 +157,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests { ClientRegistration clientRegistration = TestUtil.githubClientRegistration(); AuthorizationCodeAuthenticationProcessingFilter filter = Mockito.spy(setupFilter(clientRegistration)); - AuthenticationFailureHandler failureHandler = Mockito.mock(AuthenticationFailureHandler.class); + AuthenticationFailureHandler failureHandler = mock(AuthenticationFailureHandler.class); filter.setAuthenticationFailureHandler(failureHandler); AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionAuthorizationRequestRepository(); filter.setAuthorizationRequestRepository(authorizationRequestRepository); @@ -159,7 +169,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests { request.addParameter(OAuth2Parameter.STATE, state); MockHttpServletResponse response = new MockHttpServletResponse(); setupAuthorizationRequest(authorizationRequestRepository, request, response, clientRegistration, "some state"); - FilterChain filterChain = Mockito.mock(FilterChain.class); + FilterChain filterChain = mock(FilterChain.class); filter.doFilter(request, response, filterChain); @@ -171,7 +181,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests { ClientRegistration clientRegistration = TestUtil.githubClientRegistration(); AuthorizationCodeAuthenticationProcessingFilter filter = Mockito.spy(setupFilter(clientRegistration)); - AuthenticationFailureHandler failureHandler = Mockito.mock(AuthenticationFailureHandler.class); + AuthenticationFailureHandler failureHandler = mock(AuthenticationFailureHandler.class); filter.setAuthenticationFailureHandler(failureHandler); AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionAuthorizationRequestRepository(); filter.setAuthorizationRequestRepository(authorizationRequestRepository); @@ -184,7 +194,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests { request.addParameter(OAuth2Parameter.STATE, state); MockHttpServletResponse response = new MockHttpServletResponse(); setupAuthorizationRequest(authorizationRequestRepository, request, response, clientRegistration, state); - FilterChain filterChain = Mockito.mock(FilterChain.class); + FilterChain filterChain = mock(FilterChain.class); filter.doFilter(request, response, filterChain); @@ -209,7 +219,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests { } private AuthorizationCodeAuthenticationProcessingFilter setupFilter(ClientRegistration... clientRegistrations) throws Exception { - AuthenticationManager authenticationManager = Mockito.mock(AuthenticationManager.class); + AuthenticationManager authenticationManager = mock(AuthenticationManager.class); return setupFilter(authenticationManager, clientRegistrations); } diff --git a/samples/boot/oauth2login/src/main/resources/META-INF/oauth2-clients-defaults.yml b/samples/boot/oauth2login/src/main/resources/META-INF/oauth2-clients-defaults.yml index 32436d4c1e..ed5dbdefde 100644 --- a/samples/boot/oauth2login/src/main/resources/META-INF/oauth2-clients-defaults.yml +++ b/samples/boot/oauth2login/src/main/resources/META-INF/oauth2-clients-defaults.yml @@ -10,6 +10,7 @@ security: authorization-uri: "https://accounts.google.com/o/oauth2/v2/auth" token-uri: "https://www.googleapis.com/oauth2/v4/token" user-info-uri: "https://www.googleapis.com/oauth2/v3/userinfo" + user-name-attribute-name: "sub" jwk-set-uri: "https://www.googleapis.com/oauth2/v3/certs" client-name: Google github: @@ -38,3 +39,4 @@ security: redirect-uri: "{scheme}://{serverName}:{serverPort}{contextPath}/oauth2/authorize/code/{registrationId}" scope: openid, profile, email, address, phone client-name: Okta + user-name-attribute-name: "sub"