From 371221d7297a073c6248ea1b4e4f7c4495a714c4 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Wed, 27 Jun 2018 17:06:47 -0400 Subject: [PATCH] Support anonymous Principal for OAuth2AuthorizedClient Fixes gh-5064 --- .../OAuth2ClientConfiguration.java | 4 +- .../oauth2/client/OAuth2ClientConfigurer.java | 8 +- ...cipalOAuth2AuthorizedClientRepository.java | 104 +++++++ ...ssionOAuth2AuthorizedClientRepository.java | 89 ++++++ .../OAuth2AuthorizationCodeGrantFilter.java | 30 +- .../web/OAuth2AuthorizedClientRepository.java | 84 ++++++ ...Auth2AuthorizedClientArgumentResolver.java | 25 +- ...OAuth2AuthorizedClientRepositoryTests.java | 122 ++++++++ ...OAuth2AuthorizedClientRepositoryTests.java | 261 ++++++++++++++++++ ...uth2AuthorizationCodeGrantFilterTests.java | 69 ++++- ...AuthorizedClientArgumentResolverTests.java | 45 +-- 11 files changed, 777 insertions(+), 64 deletions(-) create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthenticatedPrincipalOAuth2AuthorizedClientRepository.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizedClientRepository.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizedClientRepository.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/AuthenticatedPrincipalOAuth2AuthorizedClientRepositoryTests.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizedClientRepositoryTests.java diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java index 28ef942609..aa538d3445 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java @@ -21,6 +21,7 @@ import org.springframework.context.annotation.Import; import org.springframework.context.annotation.ImportSelector; import org.springframework.core.type.AnnotationMetadata; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService; +import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.method.annotation.OAuth2AuthorizedClientArgumentResolver; import org.springframework.util.ClassUtils; import org.springframework.web.method.support.HandlerMethodArgumentResolver; @@ -63,7 +64,8 @@ final class OAuth2ClientConfiguration { public void addArgumentResolvers(List argumentResolvers) { if (this.authorizedClientService != null) { OAuth2AuthorizedClientArgumentResolver authorizedClientArgumentResolver = - new OAuth2AuthorizedClientArgumentResolver(this.authorizedClientService); + new OAuth2AuthorizedClientArgumentResolver( + new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(this.authorizedClientService)); argumentResolvers.add(authorizedClientArgumentResolver); } } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java index e7e6b47a80..5e85ff4bf7 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java @@ -24,6 +24,7 @@ import org.springframework.security.oauth2.client.endpoint.NimbusAuthorizationCo import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository; import org.springframework.security.oauth2.client.web.OAuth2AuthorizationCodeGrantFilter; import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter; @@ -287,9 +288,10 @@ public final class OAuth2ClientConfigurer> exte AuthenticationManager authenticationManager = builder.getSharedObject(AuthenticationManager.class); OAuth2AuthorizationCodeGrantFilter authorizationCodeGrantFilter = new OAuth2AuthorizationCodeGrantFilter( - OAuth2ClientConfigurerUtils.getClientRegistrationRepository(builder), - OAuth2ClientConfigurerUtils.getAuthorizedClientService(builder), - authenticationManager); + OAuth2ClientConfigurerUtils.getClientRegistrationRepository(builder), + new AuthenticatedPrincipalOAuth2AuthorizedClientRepository( + OAuth2ClientConfigurerUtils.getAuthorizedClientService(builder)), + authenticationManager); if (authorizationCodeGrantConfigurer.authorizationEndpointConfig.authorizationRequestRepository != null) { authorizationCodeGrantFilter.setAuthorizationRequestRepository( diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthenticatedPrincipalOAuth2AuthorizedClientRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthenticatedPrincipalOAuth2AuthorizedClientRepository.java new file mode 100644 index 0000000000..5e4f9a59e7 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthenticatedPrincipalOAuth2AuthorizedClientRepository.java @@ -0,0 +1,104 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web; + +import org.springframework.security.authentication.AuthenticationTrustResolver; +import org.springframework.security.authentication.AuthenticationTrustResolverImpl; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService; +import org.springframework.util.Assert; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +/** + * An implementation of an {@link OAuth2AuthorizedClientRepository} that + * delegates to the provided {@link OAuth2AuthorizedClientService} if the current + * {@code Principal} is authenticated, otherwise, + * to the default (or provided) {@link OAuth2AuthorizedClientRepository} + * if the current request is unauthenticated (or anonymous). + * The default {@code OAuth2AuthorizedClientRepository} is {@link HttpSessionOAuth2AuthorizedClientRepository}. + * + * @author Joe Grandja + * @since 5.1 + * @see OAuth2AuthorizedClientRepository + * @see OAuth2AuthorizedClient + * @see OAuth2AuthorizedClientService + * @see HttpSessionOAuth2AuthorizedClientRepository + */ +public final class AuthenticatedPrincipalOAuth2AuthorizedClientRepository implements OAuth2AuthorizedClientRepository { + private final AuthenticationTrustResolver authenticationTrustResolver = new AuthenticationTrustResolverImpl(); + private final OAuth2AuthorizedClientService authorizedClientService; + private OAuth2AuthorizedClientRepository anonymousAuthorizedClientRepository = new HttpSessionOAuth2AuthorizedClientRepository(); + + /** + * Constructs a {@code AuthenticatedPrincipalOAuth2AuthorizedClientRepository} using the provided parameters. + * + * @param authorizedClientService the authorized client service + */ + public AuthenticatedPrincipalOAuth2AuthorizedClientRepository(OAuth2AuthorizedClientService authorizedClientService) { + Assert.notNull(authorizedClientService, "authorizedClientService cannot be null"); + this.authorizedClientService = authorizedClientService; + } + + /** + * Sets the {@link OAuth2AuthorizedClientRepository} used for requests that are unauthenticated (or anonymous). + * The default is {@link HttpSessionOAuth2AuthorizedClientRepository}. + * + * @param anonymousAuthorizedClientRepository the repository used for requests that are unauthenticated (or anonymous) + */ + public final void setAnonymousAuthorizedClientRepository(OAuth2AuthorizedClientRepository anonymousAuthorizedClientRepository) { + Assert.notNull(anonymousAuthorizedClientRepository, "anonymousAuthorizedClientRepository cannot be null"); + this.anonymousAuthorizedClientRepository = anonymousAuthorizedClientRepository; + } + + @Override + public T loadAuthorizedClient(String clientRegistrationId, Authentication principal, + HttpServletRequest request) { + if (this.isPrincipalAuthenticated(principal)) { + return this.authorizedClientService.loadAuthorizedClient(clientRegistrationId, principal.getName()); + } else { + return this.anonymousAuthorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, request); + } + } + + @Override + public void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal, + HttpServletRequest request, HttpServletResponse response) { + if (this.isPrincipalAuthenticated(principal)) { + this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal); + } else { + this.anonymousAuthorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, request, response); + } + } + + @Override + public void removeAuthorizedClient(String clientRegistrationId, Authentication principal, + HttpServletRequest request, HttpServletResponse response) { + if (this.isPrincipalAuthenticated(principal)) { + this.authorizedClientService.removeAuthorizedClient(clientRegistrationId, principal.getName()); + } else { + this.anonymousAuthorizedClientRepository.removeAuthorizedClient(clientRegistrationId, principal, request, response); + } + } + + private boolean isPrincipalAuthenticated(Authentication authentication) { + return authentication != null && + !this.authenticationTrustResolver.isAnonymous(authentication) && + authentication.isAuthenticated(); + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizedClientRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizedClientRepository.java new file mode 100644 index 0000000000..ce76392f35 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizedClientRepository.java @@ -0,0 +1,89 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web; + +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.util.Assert; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpSession; +import java.util.HashMap; +import java.util.Map; + +/** + * An implementation of an {@link OAuth2AuthorizedClientRepository} that stores + * {@link OAuth2AuthorizedClient}'s in the {@code HttpSession}. + * + * @author Joe Grandja + * @since 5.1 + * @see OAuth2AuthorizedClientRepository + * @see OAuth2AuthorizedClient + */ +public final class HttpSessionOAuth2AuthorizedClientRepository implements OAuth2AuthorizedClientRepository { + private static final String DEFAULT_AUTHORIZED_CLIENTS_ATTR_NAME = + HttpSessionOAuth2AuthorizedClientRepository.class.getName() + ".AUTHORIZED_CLIENTS"; + private final String sessionAttributeName = DEFAULT_AUTHORIZED_CLIENTS_ATTR_NAME; + + @SuppressWarnings("unchecked") + @Override + public T loadAuthorizedClient(String clientRegistrationId, Authentication principal, + HttpServletRequest request) { + Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); + Assert.notNull(request, "request cannot be null"); + return (T) this.getAuthorizedClients(request).get(clientRegistrationId); + } + + @Override + public void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal, + HttpServletRequest request, HttpServletResponse response) { + Assert.notNull(authorizedClient, "authorizedClient cannot be null"); + Assert.notNull(request, "request cannot be null"); + Assert.notNull(response, "response cannot be null"); + Map authorizedClients = this.getAuthorizedClients(request); + authorizedClients.put(authorizedClient.getClientRegistration().getRegistrationId(), authorizedClient); + request.getSession().setAttribute(this.sessionAttributeName, authorizedClients); + } + + @Override + public void removeAuthorizedClient(String clientRegistrationId, Authentication principal, + HttpServletRequest request, HttpServletResponse response) { + Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); + Assert.notNull(request, "request cannot be null"); + Map authorizedClients = this.getAuthorizedClients(request); + if (!authorizedClients.isEmpty()) { + if (authorizedClients.remove(clientRegistrationId) != null) { + if (!authorizedClients.isEmpty()) { + request.getSession().setAttribute(this.sessionAttributeName, authorizedClients); + } else { + request.getSession().removeAttribute(this.sessionAttributeName); + } + } + } + } + + @SuppressWarnings("unchecked") + private Map getAuthorizedClients(HttpServletRequest request) { + HttpSession session = request.getSession(false); + Map authorizedClients = session == null ? null : + (Map) session.getAttribute(this.sessionAttributeName); + if (authorizedClients == null) { + authorizedClients = new HashMap<>(); + } + return authorizedClients; + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java index eac55375a8..90497366de 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java @@ -15,19 +15,11 @@ */ package org.springframework.security.oauth2.client.web; -import java.io.IOException; - -import javax.servlet.FilterChain; -import javax.servlet.ServletException; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; -import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService; import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationProvider; import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken; import org.springframework.security.oauth2.client.registration.ClientRegistration; @@ -51,6 +43,12 @@ import org.springframework.util.StringUtils; import org.springframework.web.filter.OncePerRequestFilter; import org.springframework.web.util.UriComponentsBuilder; +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.io.IOException; + /** * A {@code Filter} for the OAuth 2.0 Authorization Code Grant, * which handles the processing of the OAuth 2.0 Authorization Response. @@ -74,7 +72,7 @@ import org.springframework.web.util.UriComponentsBuilder; * Upon a successful authentication, an {@link OAuth2AuthorizedClient Authorized Client} is created by associating the * {@link OAuth2AuthorizationCodeAuthenticationToken#getClientRegistration() client} to the * {@link OAuth2AuthorizationCodeAuthenticationToken#getAccessToken() access token} and current {@code Principal} - * and saving it via the {@link OAuth2AuthorizedClientService}. + * and saving it via the {@link OAuth2AuthorizedClientRepository}. * * * @@ -88,13 +86,13 @@ import org.springframework.web.util.UriComponentsBuilder; * @see OAuth2AuthorizationRequestRedirectFilter * @see ClientRegistrationRepository * @see OAuth2AuthorizedClient - * @see OAuth2AuthorizedClientService + * @see OAuth2AuthorizedClientRepository * @see Section 4.1 Authorization Code Grant * @see Section 4.1.2 Authorization Response */ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter { private final ClientRegistrationRepository clientRegistrationRepository; - private final OAuth2AuthorizedClientService authorizedClientService; + private final OAuth2AuthorizedClientRepository authorizedClientRepository; private final AuthenticationManager authenticationManager; private AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository(); @@ -106,17 +104,17 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter { * Constructs an {@code OAuth2AuthorizationCodeGrantFilter} using the provided parameters. * * @param clientRegistrationRepository the repository of client registrations - * @param authorizedClientService the authorized client service + * @param authorizedClientRepository the authorized client repository * @param authenticationManager the authentication manager */ public OAuth2AuthorizationCodeGrantFilter(ClientRegistrationRepository clientRegistrationRepository, - OAuth2AuthorizedClientService authorizedClientService, + OAuth2AuthorizedClientRepository authorizedClientRepository, AuthenticationManager authenticationManager) { Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); - Assert.notNull(authorizedClientService, "authorizedClientService cannot be null"); + Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); Assert.notNull(authenticationManager, "authenticationManager cannot be null"); this.clientRegistrationRepository = clientRegistrationRepository; - this.authorizedClientService = authorizedClientService; + this.authorizedClientRepository = authorizedClientRepository; this.authenticationManager = authenticationManager; } @@ -201,7 +199,7 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter { authenticationResult.getAccessToken(), authenticationResult.getRefreshToken()); - this.authorizedClientService.saveAuthorizedClient(authorizedClient, currentAuthentication); + this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, currentAuthentication, request, response); String redirectUrl = authorizationResponse.getRedirectUri(); SavedRequest savedRequest = this.requestCache.getRequest(request, response); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizedClientRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizedClientRepository.java new file mode 100644 index 0000000000..b18c785861 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizedClientRepository.java @@ -0,0 +1,84 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web; + +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.OAuth2AccessToken; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +/** + * Implementations of this interface are responsible for the persistence + * of {@link OAuth2AuthorizedClient Authorized Client(s)} between requests. + * + *

+ * The primary purpose of an {@link OAuth2AuthorizedClient Authorized Client} + * is to associate an {@link OAuth2AuthorizedClient#getAccessToken() Access Token} credential + * to a {@link OAuth2AuthorizedClient#getClientRegistration() Client} and Resource Owner, + * who is the {@link OAuth2AuthorizedClient#getPrincipalName() Principal} + * that originally granted the authorization. + * + * @author Joe Grandja + * @since 5.1 + * @see OAuth2AuthorizedClient + * @see ClientRegistration + * @see Authentication + * @see OAuth2AccessToken + */ +public interface OAuth2AuthorizedClientRepository { + + /** + * Returns the {@link OAuth2AuthorizedClient} associated to the + * provided client registration identifier and End-User {@link Authentication} (Resource Owner) + * or {@code null} if not available. + * + * @param clientRegistrationId the identifier for the client's registration + * @param principal the End-User {@link Authentication} (Resource Owner) + * @param request the {@code HttpServletRequest} + * @param a type of OAuth2AuthorizedClient + * @return the {@link OAuth2AuthorizedClient} or {@code null} if not available + */ + T loadAuthorizedClient(String clientRegistrationId, Authentication principal, + HttpServletRequest request); + + /** + * Saves the {@link OAuth2AuthorizedClient} associating it to + * the provided End-User {@link Authentication} (Resource Owner). + * + * @param authorizedClient the authorized client + * @param principal the End-User {@link Authentication} (Resource Owner) + * @param request the {@code HttpServletRequest} + * @param response the {@code HttpServletResponse} + */ + void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal, + HttpServletRequest request, HttpServletResponse response); + + /** + * Removes the {@link OAuth2AuthorizedClient} associated to the + * provided client registration identifier and End-User {@link Authentication} (Resource Owner). + * + * @param clientRegistrationId the identifier for the client's registration + * @param principal the End-User {@link Authentication} (Resource Owner) + * @param request the {@code HttpServletRequest} + * @param response the {@code HttpServletResponse} + */ + void removeAuthorizedClient(String clientRegistrationId, Authentication principal, + HttpServletRequest request, HttpServletResponse response); + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java index e5c0fd1b95..67daf2955b 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java @@ -23,9 +23,9 @@ import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; -import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService; import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.web.bind.support.WebDataBinderFactory; @@ -33,6 +33,8 @@ import org.springframework.web.context.request.NativeWebRequest; import org.springframework.web.method.support.HandlerMethodArgumentResolver; import org.springframework.web.method.support.ModelAndViewContainer; +import javax.servlet.http.HttpServletRequest; + /** * An implementation of a {@link HandlerMethodArgumentResolver} that is capable * of resolving a method parameter to an argument value of type {@link OAuth2AuthorizedClient}. @@ -54,16 +56,16 @@ import org.springframework.web.method.support.ModelAndViewContainer; * @see RegisteredOAuth2AuthorizedClient */ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMethodArgumentResolver { - private final OAuth2AuthorizedClientService authorizedClientService; + private final OAuth2AuthorizedClientRepository authorizedClientRepository; /** * Constructs an {@code OAuth2AuthorizedClientArgumentResolver} using the provided parameters. * - * @param authorizedClientService the authorized client service + * @param authorizedClientRepository the authorized client repository */ - public OAuth2AuthorizedClientArgumentResolver(OAuth2AuthorizedClientService authorizedClientService) { - Assert.notNull(authorizedClientService, "authorizedClientService cannot be null"); - this.authorizedClientService = authorizedClientService; + public OAuth2AuthorizedClientArgumentResolver(OAuth2AuthorizedClientRepository authorizedClientRepository) { + Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); + this.authorizedClientRepository = authorizedClientRepository; } @Override @@ -98,15 +100,8 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth "It must be provided via @RegisteredOAuth2AuthorizedClient(\"client1\") or @RegisteredOAuth2AuthorizedClient(registrationId = \"client1\")."); } - if (principal == null) { - // An Authentication is required given that an OAuth2AuthorizedClient is associated to a Principal - throw new IllegalStateException("Unable to resolve the Authorized Client with registration identifier \"" + - clientRegistrationId + "\". An \"authenticated\" or \"unauthenticated\" session is required. " + - "To allow for unauthenticated access, ensure HttpSecurity.anonymous() is configured."); - } - - OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient( - clientRegistrationId, principal.getName()); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient( + clientRegistrationId, principal, webRequest.getNativeRequest(HttpServletRequest.class)); if (authorizedClient == null) { throw new ClientAuthorizationRequiredException(clientRegistrationId); } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/AuthenticatedPrincipalOAuth2AuthorizedClientRepositoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/AuthenticatedPrincipalOAuth2AuthorizedClientRepositoryTests.java new file mode 100644 index 0000000000..e24c1e74d9 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/AuthenticatedPrincipalOAuth2AuthorizedClientRepositoryTests.java @@ -0,0 +1,122 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.AnonymousAuthenticationToken; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link AuthenticatedPrincipalOAuth2AuthorizedClientRepository}. + * + * @author Joe Grandja + */ +public class AuthenticatedPrincipalOAuth2AuthorizedClientRepositoryTests { + private String registrationId = "registrationId"; + private String principalName = "principalName"; + private OAuth2AuthorizedClientService authorizedClientService; + private OAuth2AuthorizedClientRepository anonymousAuthorizedClientRepository; + private AuthenticatedPrincipalOAuth2AuthorizedClientRepository authorizedClientRepository; + private MockHttpServletRequest request; + private MockHttpServletResponse response; + + @Before + public void setup() { + this.authorizedClientService = mock(OAuth2AuthorizedClientService.class); + this.anonymousAuthorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class); + this.authorizedClientRepository = new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(this.authorizedClientService); + this.authorizedClientRepository.setAnonymousAuthorizedClientRepository(this.anonymousAuthorizedClientRepository); + this.request = new MockHttpServletRequest(); + this.response = new MockHttpServletResponse(); + } + + @Test + public void constructorWhenAuthorizedClientServiceIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void setAuthorizedClientRepositoryWhenAuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientRepository.setAnonymousAuthorizedClientRepository(null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void loadAuthorizedClientWhenAuthenticatedPrincipalThenLoadFromService() { + Authentication authentication = this.createAuthenticatedPrincipal(); + this.authorizedClientRepository.loadAuthorizedClient(this.registrationId, authentication, this.request); + verify(this.authorizedClientService).loadAuthorizedClient(this.registrationId, this.principalName); + } + + @Test + public void loadAuthorizedClientWhenAnonymousPrincipalThenLoadFromAnonymousRepository() { + Authentication authentication = this.createAnonymousPrincipal(); + this.authorizedClientRepository.loadAuthorizedClient(this.registrationId, authentication, this.request); + verify(this.anonymousAuthorizedClientRepository).loadAuthorizedClient(this.registrationId, authentication, this.request); + } + + @Test + public void saveAuthorizedClientWhenAuthenticatedPrincipalThenSaveToService() { + Authentication authentication = this.createAuthenticatedPrincipal(); + OAuth2AuthorizedClient authorizedClient = mock(OAuth2AuthorizedClient.class); + this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, authentication, this.request, this.response); + verify(this.authorizedClientService).saveAuthorizedClient(authorizedClient, authentication); + } + + @Test + public void saveAuthorizedClientWhenAnonymousPrincipalThenSaveToAnonymousRepository() { + Authentication authentication = this.createAnonymousPrincipal(); + OAuth2AuthorizedClient authorizedClient = mock(OAuth2AuthorizedClient.class); + this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, authentication, this.request, this.response); + verify(this.anonymousAuthorizedClientRepository).saveAuthorizedClient(authorizedClient, authentication, this.request, this.response); + } + + @Test + public void removeAuthorizedClientWhenAuthenticatedPrincipalThenRemoveFromService() { + Authentication authentication = this.createAuthenticatedPrincipal(); + this.authorizedClientRepository.removeAuthorizedClient(this.registrationId, authentication, this.request, this.response); + verify(this.authorizedClientService).removeAuthorizedClient(this.registrationId, this.principalName); + } + + @Test + public void removeAuthorizedClientWhenAnonymousPrincipalThenRemoveFromAnonymousRepository() { + Authentication authentication = this.createAnonymousPrincipal(); + this.authorizedClientRepository.removeAuthorizedClient(this.registrationId, authentication, this.request, this.response); + verify(this.anonymousAuthorizedClientRepository).removeAuthorizedClient(this.registrationId, authentication, this.request, this.response); + } + + private Authentication createAuthenticatedPrincipal() { + TestingAuthenticationToken authentication = new TestingAuthenticationToken(this.principalName, "password"); + authentication.setAuthenticated(true); + return authentication; + } + + private Authentication createAnonymousPrincipal() { + return new AnonymousAuthenticationToken("key-1234", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizedClientRepositoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizedClientRepositoryTests.java new file mode 100644 index 0000000000..1578dd2cdf --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizedClientRepositoryTests.java @@ -0,0 +1,261 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; + +import javax.servlet.http.HttpSession; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +/** + * Tests for {@link HttpSessionOAuth2AuthorizedClientRepository}. + * + * @author Joe Grandja + */ +public class HttpSessionOAuth2AuthorizedClientRepositoryTests { + private String registrationId1 = "registration-1"; + private String registrationId2 = "registration-2"; + private String principalName1 = "principalName-1"; + + private ClientRegistration registration1 = ClientRegistration.withRegistrationId(this.registrationId1) + .clientId("client-1") + .clientSecret("secret") + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .redirectUriTemplate("{baseUrl}/login/oauth2/code/{registrationId}") + .scope("user") + .authorizationUri("https://provider.com/oauth2/authorize") + .tokenUri("https://provider.com/oauth2/token") + .userInfoUri("https://provider.com/oauth2/user") + .userNameAttributeName("id") + .clientName("client-1") + .build(); + + private ClientRegistration registration2 = ClientRegistration.withRegistrationId(this.registrationId2) + .clientId("client-2") + .clientSecret("secret") + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .redirectUriTemplate("{baseUrl}/login/oauth2/code/{registrationId}") + .scope("openid", "profile", "email") + .authorizationUri("https://provider.com/oauth2/authorize") + .tokenUri("https://provider.com/oauth2/token") + .userInfoUri("https://provider.com/oauth2/userinfo") + .jwkSetUri("https://provider.com/oauth2/keys") + .clientName("client-2") + .build(); + + private HttpSessionOAuth2AuthorizedClientRepository authorizedClientRepository = + new HttpSessionOAuth2AuthorizedClientRepository(); + + private MockHttpServletRequest request; + + private MockHttpServletResponse response; + + @Before + public void setup() { + this.request = new MockHttpServletRequest(); + this.response = new MockHttpServletResponse(); + } + + @Test + public void loadAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientRepository.loadAuthorizedClient(null, null, this.request)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void loadAuthorizedClientWhenPrincipalNameIsNullThenExceptionNotThrown() { + this.authorizedClientRepository.loadAuthorizedClient(this.registrationId1, null, this.request); + } + + @Test + public void loadAuthorizedClientWhenRequestIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientRepository.loadAuthorizedClient(this.registrationId1, null, null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void loadAuthorizedClientWhenClientRegistrationNotFoundThenReturnNull() { + OAuth2AuthorizedClient authorizedClient = + this.authorizedClientRepository.loadAuthorizedClient("registration-not-found", null, this.request); + assertThat(authorizedClient).isNull(); + } + + @Test + public void loadAuthorizedClientWhenSavedThenReturnAuthorizedClient() { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + this.registration1, this.principalName1, mock(OAuth2AccessToken.class)); + this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, this.request, this.response); + + OAuth2AuthorizedClient loadedAuthorizedClient = + this.authorizedClientRepository.loadAuthorizedClient(this.registrationId1, null, this.request); + assertThat(loadedAuthorizedClient).isEqualTo(authorizedClient); + } + + @Test + public void saveAuthorizedClientWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientRepository.saveAuthorizedClient(null, null, this.request, this.response)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void saveAuthorizedClientWhenAuthenticationIsNullThenExceptionNotThrown() { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + this.registration2, this.principalName1, mock(OAuth2AccessToken.class)); + this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, this.request, this.response); + } + + @Test + public void saveAuthorizedClientWhenRequestIsNullThenThrowIllegalArgumentException() { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + this.registration2, this.principalName1, mock(OAuth2AccessToken.class)); + assertThatThrownBy(() -> this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, null, this.response)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void saveAuthorizedClientWhenResponseIsNullThenThrowIllegalArgumentException() { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + this.registration2, this.principalName1, mock(OAuth2AccessToken.class)); + assertThatThrownBy(() -> this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, this.request, null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void saveAuthorizedClientWhenSavedThenSavedToSession() { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + this.registration2, this.principalName1, mock(OAuth2AccessToken.class)); + this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, this.request, this.response); + + HttpSession session = this.request.getSession(false); + assertThat(session).isNotNull(); + + @SuppressWarnings("unchecked") + Map authorizedClients = (Map) + session.getAttribute(HttpSessionOAuth2AuthorizedClientRepository.class.getName() + ".AUTHORIZED_CLIENTS"); + assertThat(authorizedClients).isNotEmpty(); + assertThat(authorizedClients).hasSize(1); + assertThat(authorizedClients.values().iterator().next()).isSameAs(authorizedClient); + } + + @Test + public void removeAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientRepository.removeAuthorizedClient( + null, null, this.request, this.response)).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void removeAuthorizedClientWhenPrincipalNameIsNullThenExceptionNotThrown() { + this.authorizedClientRepository.removeAuthorizedClient(this.registrationId1, null, this.request, this.response); + } + + @Test + public void removeAuthorizedClientWhenRequestIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientRepository.removeAuthorizedClient( + this.registrationId1, null, null, this.response)).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void removeAuthorizedClientWhenResponseIsNullThenExceptionNotThrown() { + this.authorizedClientRepository.removeAuthorizedClient(this.registrationId1, null, this.request, null); + } + + @Test + public void removeAuthorizedClientWhenNotSavedThenSessionNotCreated() { + this.authorizedClientRepository.removeAuthorizedClient( + this.registrationId2, null, this.request, this.response); + assertThat(this.request.getSession(false)).isNull(); + } + + @Test + public void removeAuthorizedClientWhenClient1SavedAndClient2RemovedThenClient1NotRemoved() { + OAuth2AuthorizedClient authorizedClient1 = new OAuth2AuthorizedClient( + this.registration1, this.principalName1, mock(OAuth2AccessToken.class)); + this.authorizedClientRepository.saveAuthorizedClient(authorizedClient1, null, this.request, this.response); + + // Remove registrationId2 (never added so is not removed either) + this.authorizedClientRepository.removeAuthorizedClient( + this.registrationId2, null, this.request, this.response); + + OAuth2AuthorizedClient loadedAuthorizedClient1 = this.authorizedClientRepository.loadAuthorizedClient( + this.registrationId1, null, this.request); + assertThat(loadedAuthorizedClient1).isNotNull(); + assertThat(loadedAuthorizedClient1).isSameAs(authorizedClient1); + } + + @Test + public void removeAuthorizedClientWhenSavedThenRemoved() { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + this.registration2, this.principalName1, mock(OAuth2AccessToken.class)); + this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, this.request, this.response); + OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientRepository.loadAuthorizedClient( + this.registrationId2, null, this.request); + assertThat(loadedAuthorizedClient).isSameAs(authorizedClient); + this.authorizedClientRepository.removeAuthorizedClient( + this.registrationId2, null, this.request, this.response); + loadedAuthorizedClient = this.authorizedClientRepository.loadAuthorizedClient( + this.registrationId2, null, this.request); + assertThat(loadedAuthorizedClient).isNull(); + } + + @Test + public void removeAuthorizedClientWhenSavedThenRemovedFromSession() { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + this.registration1, this.principalName1, mock(OAuth2AccessToken.class)); + this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, this.request, this.response); + OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientRepository.loadAuthorizedClient( + this.registrationId1, null, this.request); + assertThat(loadedAuthorizedClient).isSameAs(authorizedClient); + this.authorizedClientRepository.removeAuthorizedClient( + this.registrationId1, null, this.request, this.response); + + HttpSession session = this.request.getSession(false); + assertThat(session).isNotNull(); + assertThat(session.getAttribute(HttpSessionOAuth2AuthorizedClientRepository.class.getName() + ".AUTHORIZED_CLIENTS")).isNull(); + } + + @Test + public void removeAuthorizedClientWhenClient1Client2SavedAndClient1RemovedThenClient2NotRemoved() { + OAuth2AuthorizedClient authorizedClient1 = new OAuth2AuthorizedClient( + this.registration1, this.principalName1, mock(OAuth2AccessToken.class)); + this.authorizedClientRepository.saveAuthorizedClient(authorizedClient1, null, this.request, this.response); + + OAuth2AuthorizedClient authorizedClient2 = new OAuth2AuthorizedClient( + this.registration2, this.principalName1, mock(OAuth2AccessToken.class)); + this.authorizedClientRepository.saveAuthorizedClient(authorizedClient2, null, this.request, this.response); + + this.authorizedClientRepository.removeAuthorizedClient( + this.registrationId1, null, this.request, this.response); + + OAuth2AuthorizedClient loadedAuthorizedClient2 = this.authorizedClientRepository.loadAuthorizedClient( + this.registrationId2, null, this.request); + assertThat(loadedAuthorizedClient2).isNotNull(); + assertThat(loadedAuthorizedClient2).isSameAs(authorizedClient2); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java index 3093adc02b..e9b615e98f 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java @@ -15,6 +15,7 @@ */ package org.springframework.security.oauth2.client.web; +import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -23,9 +24,11 @@ import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.modules.junit4.PowerMockRunner; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; +import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService; @@ -51,6 +54,7 @@ import org.springframework.security.web.savedrequest.RequestCache; import javax.servlet.FilterChain; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpSession; import java.util.HashMap; import java.util.Map; @@ -71,12 +75,13 @@ public class OAuth2AuthorizationCodeGrantFilterTests { private String principalName1 = "principal-1"; private ClientRegistrationRepository clientRegistrationRepository; private OAuth2AuthorizedClientService authorizedClientService; + private OAuth2AuthorizedClientRepository authorizedClientRepository; private AuthenticationManager authenticationManager; private AuthorizationRequestRepository authorizationRequestRepository; private OAuth2AuthorizationCodeGrantFilter filter; @Before - public void setUp() { + public void setup() { this.registration1 = ClientRegistration.withRegistrationId("registration-1") .clientId("client-1") .clientSecret("secret") @@ -92,32 +97,39 @@ public class OAuth2AuthorizationCodeGrantFilterTests { .build(); this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(this.registration1); this.authorizedClientService = new InMemoryOAuth2AuthorizedClientService(this.clientRegistrationRepository); + this.authorizedClientRepository = new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(this.authorizedClientService); this.authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository(); this.authenticationManager = mock(AuthenticationManager.class); this.filter = spy(new OAuth2AuthorizationCodeGrantFilter( - this.clientRegistrationRepository, this.authorizedClientService, this.authenticationManager)); + this.clientRegistrationRepository, this.authorizedClientRepository, this.authenticationManager)); this.filter.setAuthorizationRequestRepository(this.authorizationRequestRepository); - + TestingAuthenticationToken authentication = new TestingAuthenticationToken(this.principalName1, "password"); + authentication.setAuthenticated(true); SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); - securityContext.setAuthentication(new TestingAuthenticationToken(this.principalName1, "password")); + securityContext.setAuthentication(authentication); SecurityContextHolder.setContext(securityContext); } + @After + public void cleanup() { + SecurityContextHolder.clearContext(); + } + @Test public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizationCodeGrantFilter(null, this.authorizedClientService, this.authenticationManager)) + assertThatThrownBy(() -> new OAuth2AuthorizationCodeGrantFilter(null, this.authorizedClientRepository, this.authenticationManager)) .isInstanceOf(IllegalArgumentException.class); } @Test - public void constructorWhenAuthorizedClientServiceIsNullThenThrowIllegalArgumentException() { + public void constructorWhenAuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> new OAuth2AuthorizationCodeGrantFilter(this.clientRegistrationRepository, null, this.authenticationManager)) .isInstanceOf(IllegalArgumentException.class); } @Test public void constructorWhenAuthenticationManagerIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizationCodeGrantFilter(this.clientRegistrationRepository, this.authorizedClientService, null)) + assertThatThrownBy(() -> new OAuth2AuthorizationCodeGrantFilter(this.clientRegistrationRepository, this.authorizedClientRepository, null)) .isInstanceOf(IllegalArgumentException.class); } @@ -218,7 +230,7 @@ public class OAuth2AuthorizationCodeGrantFilterTests { } @Test - public void doFilterWhenAuthorizationResponseSuccessThenAuthorizedClientSaved() throws Exception { + public void doFilterWhenAuthorizationResponseSuccessThenAuthorizedClientSavedToService() throws Exception { String requestUri = "/callback/client-1"; MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); @@ -285,6 +297,47 @@ public class OAuth2AuthorizationCodeGrantFilterTests { assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/saved-request"); } + @Test + public void doFilterWhenAuthorizationResponseSuccessAndAnonymousAccessThenAuthorizedClientSavedToHttpSession() throws Exception { + AnonymousAuthenticationToken anonymousPrincipal = + new AnonymousAuthenticationToken("key-1234", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); + SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); + securityContext.setAuthentication(anonymousPrincipal); + SecurityContextHolder.setContext(securityContext); + + String requestUri = "/callback/client-1"; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + request.addParameter(OAuth2ParameterNames.CODE, "code"); + request.addParameter(OAuth2ParameterNames.STATE, "state"); + + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.setUpAuthorizationRequest(request, response, this.registration1); + this.setUpAuthenticationResult(this.registration1); + + this.filter.doFilter(request, response, filterChain); + + OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient( + this.registration1.getRegistrationId(), anonymousPrincipal, request); + assertThat(authorizedClient).isNotNull(); + + assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration1); + assertThat(authorizedClient.getPrincipalName()).isEqualTo(anonymousPrincipal.getName()); + assertThat(authorizedClient.getAccessToken()).isNotNull(); + + HttpSession session = request.getSession(false); + assertThat(session).isNotNull(); + + @SuppressWarnings("unchecked") + Map authorizedClients = (Map) + session.getAttribute(HttpSessionOAuth2AuthorizedClientRepository.class.getName() + ".AUTHORIZED_CLIENTS"); + assertThat(authorizedClients).isNotEmpty(); + assertThat(authorizedClients).hasSize(1); + assertThat(authorizedClients.values().iterator().next()).isSameAs(authorizedClient); + } + private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response, ClientRegistration registration) { Map additionalParameters = new HashMap<>(); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java index 65fd59725c..d67d527da0 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java @@ -15,19 +15,23 @@ */ package org.springframework.security.oauth2.client.web.method.annotation; +import org.junit.After; import org.junit.Before; import org.junit.Test; import org.springframework.core.MethodParameter; +import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; -import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService; import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.util.ReflectionUtils; +import org.springframework.web.context.request.ServletWebRequest; +import javax.servlet.http.HttpServletRequest; import java.lang.reflect.Method; import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; @@ -43,21 +47,29 @@ import static org.mockito.Mockito.when; * @author Joe Grandja */ public class OAuth2AuthorizedClientArgumentResolverTests { - private OAuth2AuthorizedClientService authorizedClientService; + private OAuth2AuthorizedClientRepository authorizedClientRepository; private OAuth2AuthorizedClientArgumentResolver argumentResolver; private OAuth2AuthorizedClient authorizedClient; + private MockHttpServletRequest request; @Before - public void setUp() { - this.authorizedClientService = mock(OAuth2AuthorizedClientService.class); - this.argumentResolver = new OAuth2AuthorizedClientArgumentResolver(this.authorizedClientService); + public void setup() { + this.authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class); + this.argumentResolver = new OAuth2AuthorizedClientArgumentResolver(this.authorizedClientRepository); this.authorizedClient = mock(OAuth2AuthorizedClient.class); - when(this.authorizedClientService.loadAuthorizedClient(anyString(), any())).thenReturn(this.authorizedClient); + this.request = new MockHttpServletRequest(); + when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any(HttpServletRequest.class))) + .thenReturn(this.authorizedClient); SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); securityContext.setAuthentication(mock(Authentication.class)); SecurityContextHolder.setContext(securityContext); } + @After + public void cleanup() { + SecurityContextHolder.clearContext(); + } + @Test public void constructorWhenOAuth2AuthorizedClientServiceIsNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(null)) @@ -104,31 +116,22 @@ public class OAuth2AuthorizedClientArgumentResolverTests { securityContext.setAuthentication(authentication); SecurityContextHolder.setContext(securityContext); MethodParameter methodParameter = this.getMethodParameter("registrationIdEmpty", OAuth2AuthorizedClient.class); - this.argumentResolver.resolveArgument(methodParameter, null, null, null); - } - - @Test - public void resolveArgumentWhenParameterTypeOAuth2AuthorizedClientAndCurrentAuthenticationNullThenThrowIllegalStateException() throws Exception { - SecurityContextHolder.clearContext(); - MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class); - assertThatThrownBy(() -> this.argumentResolver.resolveArgument(methodParameter, null, null, null)) - .isInstanceOf(IllegalStateException.class) - .hasMessage("Unable to resolve the Authorized Client with registration identifier \"client1\". " + - "An \"authenticated\" or \"unauthenticated\" session is required. " + - "To allow for unauthenticated access, ensure HttpSecurity.anonymous() is configured."); + this.argumentResolver.resolveArgument(methodParameter, null, new ServletWebRequest(this.request), null); } @Test public void resolveArgumentWhenOAuth2AuthorizedClientFoundThenResolves() throws Exception { MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class); - assertThat(this.argumentResolver.resolveArgument(methodParameter, null, null, null)).isSameAs(this.authorizedClient); + assertThat(this.argumentResolver.resolveArgument( + methodParameter, null, new ServletWebRequest(this.request), null)).isSameAs(this.authorizedClient); } @Test public void resolveArgumentWhenOAuth2AuthorizedClientNotFoundThenThrowClientAuthorizationRequiredException() throws Exception { - when(this.authorizedClientService.loadAuthorizedClient(anyString(), any())).thenReturn(null); + when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any(HttpServletRequest.class))) + .thenReturn(null); MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class); - assertThatThrownBy(() -> this.argumentResolver.resolveArgument(methodParameter, null, null, null)) + assertThatThrownBy(() -> this.argumentResolver.resolveArgument(methodParameter, null, new ServletWebRequest(this.request), null)) .isInstanceOf(ClientAuthorizationRequiredException.class); }