Support anonymous Principal for OAuth2AuthorizedClient

Fixes gh-5064
This commit is contained in:
Joe Grandja 2018-06-27 17:06:47 -04:00 committed by Rob Winch
parent 779597af2a
commit 371221d729
11 changed files with 777 additions and 64 deletions

View File

@ -21,6 +21,7 @@ import org.springframework.context.annotation.Import;
import org.springframework.context.annotation.ImportSelector;
import org.springframework.core.type.AnnotationMetadata;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.method.annotation.OAuth2AuthorizedClientArgumentResolver;
import org.springframework.util.ClassUtils;
import org.springframework.web.method.support.HandlerMethodArgumentResolver;
@ -63,7 +64,8 @@ final class OAuth2ClientConfiguration {
public void addArgumentResolvers(List<HandlerMethodArgumentResolver> argumentResolvers) {
if (this.authorizedClientService != null) {
OAuth2AuthorizedClientArgumentResolver authorizedClientArgumentResolver =
new OAuth2AuthorizedClientArgumentResolver(this.authorizedClientService);
new OAuth2AuthorizedClientArgumentResolver(
new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(this.authorizedClientService));
argumentResolvers.add(authorizedClientArgumentResolver);
}
}

View File

@ -24,6 +24,7 @@ import org.springframework.security.oauth2.client.endpoint.NimbusAuthorizationCo
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationCodeGrantFilter;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter;
@ -287,9 +288,10 @@ public final class OAuth2ClientConfigurer<B extends HttpSecurityBuilder<B>> exte
AuthenticationManager authenticationManager = builder.getSharedObject(AuthenticationManager.class);
OAuth2AuthorizationCodeGrantFilter authorizationCodeGrantFilter = new OAuth2AuthorizationCodeGrantFilter(
OAuth2ClientConfigurerUtils.getClientRegistrationRepository(builder),
OAuth2ClientConfigurerUtils.getAuthorizedClientService(builder),
authenticationManager);
OAuth2ClientConfigurerUtils.getClientRegistrationRepository(builder),
new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(
OAuth2ClientConfigurerUtils.getAuthorizedClientService(builder)),
authenticationManager);
if (authorizationCodeGrantConfigurer.authorizationEndpointConfig.authorizationRequestRepository != null) {
authorizationCodeGrantFilter.setAuthorizationRequestRepository(

View File

@ -0,0 +1,104 @@
/*
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.web;
import org.springframework.security.authentication.AuthenticationTrustResolver;
import org.springframework.security.authentication.AuthenticationTrustResolverImpl;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
import org.springframework.util.Assert;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
/**
* An implementation of an {@link OAuth2AuthorizedClientRepository} that
* delegates to the provided {@link OAuth2AuthorizedClientService} if the current
* {@code Principal} is authenticated, otherwise,
* to the default (or provided) {@link OAuth2AuthorizedClientRepository}
* if the current request is unauthenticated (or anonymous).
* The default {@code OAuth2AuthorizedClientRepository} is {@link HttpSessionOAuth2AuthorizedClientRepository}.
*
* @author Joe Grandja
* @since 5.1
* @see OAuth2AuthorizedClientRepository
* @see OAuth2AuthorizedClient
* @see OAuth2AuthorizedClientService
* @see HttpSessionOAuth2AuthorizedClientRepository
*/
public final class AuthenticatedPrincipalOAuth2AuthorizedClientRepository implements OAuth2AuthorizedClientRepository {
private final AuthenticationTrustResolver authenticationTrustResolver = new AuthenticationTrustResolverImpl();
private final OAuth2AuthorizedClientService authorizedClientService;
private OAuth2AuthorizedClientRepository anonymousAuthorizedClientRepository = new HttpSessionOAuth2AuthorizedClientRepository();
/**
* Constructs a {@code AuthenticatedPrincipalOAuth2AuthorizedClientRepository} using the provided parameters.
*
* @param authorizedClientService the authorized client service
*/
public AuthenticatedPrincipalOAuth2AuthorizedClientRepository(OAuth2AuthorizedClientService authorizedClientService) {
Assert.notNull(authorizedClientService, "authorizedClientService cannot be null");
this.authorizedClientService = authorizedClientService;
}
/**
* Sets the {@link OAuth2AuthorizedClientRepository} used for requests that are unauthenticated (or anonymous).
* The default is {@link HttpSessionOAuth2AuthorizedClientRepository}.
*
* @param anonymousAuthorizedClientRepository the repository used for requests that are unauthenticated (or anonymous)
*/
public final void setAnonymousAuthorizedClientRepository(OAuth2AuthorizedClientRepository anonymousAuthorizedClientRepository) {
Assert.notNull(anonymousAuthorizedClientRepository, "anonymousAuthorizedClientRepository cannot be null");
this.anonymousAuthorizedClientRepository = anonymousAuthorizedClientRepository;
}
@Override
public <T extends OAuth2AuthorizedClient> T loadAuthorizedClient(String clientRegistrationId, Authentication principal,
HttpServletRequest request) {
if (this.isPrincipalAuthenticated(principal)) {
return this.authorizedClientService.loadAuthorizedClient(clientRegistrationId, principal.getName());
} else {
return this.anonymousAuthorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, request);
}
}
@Override
public void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal,
HttpServletRequest request, HttpServletResponse response) {
if (this.isPrincipalAuthenticated(principal)) {
this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal);
} else {
this.anonymousAuthorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, request, response);
}
}
@Override
public void removeAuthorizedClient(String clientRegistrationId, Authentication principal,
HttpServletRequest request, HttpServletResponse response) {
if (this.isPrincipalAuthenticated(principal)) {
this.authorizedClientService.removeAuthorizedClient(clientRegistrationId, principal.getName());
} else {
this.anonymousAuthorizedClientRepository.removeAuthorizedClient(clientRegistrationId, principal, request, response);
}
}
private boolean isPrincipalAuthenticated(Authentication authentication) {
return authentication != null &&
!this.authenticationTrustResolver.isAnonymous(authentication) &&
authentication.isAuthenticated();
}
}

View File

@ -0,0 +1,89 @@
/*
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.web;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.util.Assert;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import java.util.HashMap;
import java.util.Map;
/**
* An implementation of an {@link OAuth2AuthorizedClientRepository} that stores
* {@link OAuth2AuthorizedClient}'s in the {@code HttpSession}.
*
* @author Joe Grandja
* @since 5.1
* @see OAuth2AuthorizedClientRepository
* @see OAuth2AuthorizedClient
*/
public final class HttpSessionOAuth2AuthorizedClientRepository implements OAuth2AuthorizedClientRepository {
private static final String DEFAULT_AUTHORIZED_CLIENTS_ATTR_NAME =
HttpSessionOAuth2AuthorizedClientRepository.class.getName() + ".AUTHORIZED_CLIENTS";
private final String sessionAttributeName = DEFAULT_AUTHORIZED_CLIENTS_ATTR_NAME;
@SuppressWarnings("unchecked")
@Override
public <T extends OAuth2AuthorizedClient> T loadAuthorizedClient(String clientRegistrationId, Authentication principal,
HttpServletRequest request) {
Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
Assert.notNull(request, "request cannot be null");
return (T) this.getAuthorizedClients(request).get(clientRegistrationId);
}
@Override
public void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal,
HttpServletRequest request, HttpServletResponse response) {
Assert.notNull(authorizedClient, "authorizedClient cannot be null");
Assert.notNull(request, "request cannot be null");
Assert.notNull(response, "response cannot be null");
Map<String, OAuth2AuthorizedClient> authorizedClients = this.getAuthorizedClients(request);
authorizedClients.put(authorizedClient.getClientRegistration().getRegistrationId(), authorizedClient);
request.getSession().setAttribute(this.sessionAttributeName, authorizedClients);
}
@Override
public void removeAuthorizedClient(String clientRegistrationId, Authentication principal,
HttpServletRequest request, HttpServletResponse response) {
Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
Assert.notNull(request, "request cannot be null");
Map<String, OAuth2AuthorizedClient> authorizedClients = this.getAuthorizedClients(request);
if (!authorizedClients.isEmpty()) {
if (authorizedClients.remove(clientRegistrationId) != null) {
if (!authorizedClients.isEmpty()) {
request.getSession().setAttribute(this.sessionAttributeName, authorizedClients);
} else {
request.getSession().removeAttribute(this.sessionAttributeName);
}
}
}
}
@SuppressWarnings("unchecked")
private Map<String, OAuth2AuthorizedClient> getAuthorizedClients(HttpServletRequest request) {
HttpSession session = request.getSession(false);
Map<String, OAuth2AuthorizedClient> authorizedClients = session == null ? null :
(Map<String, OAuth2AuthorizedClient>) session.getAttribute(this.sessionAttributeName);
if (authorizedClients == null) {
authorizedClients = new HashMap<>();
}
return authorizedClients;
}
}

View File

@ -15,19 +15,11 @@
*/
package org.springframework.security.oauth2.client.web;
import java.io.IOException;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.security.authentication.AuthenticationDetailsSource;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationProvider;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
@ -51,6 +43,12 @@ import org.springframework.util.StringUtils;
import org.springframework.web.filter.OncePerRequestFilter;
import org.springframework.web.util.UriComponentsBuilder;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
/**
* A {@code Filter} for the OAuth 2.0 Authorization Code Grant,
* which handles the processing of the OAuth 2.0 Authorization Response.
@ -74,7 +72,7 @@ import org.springframework.web.util.UriComponentsBuilder;
* Upon a successful authentication, an {@link OAuth2AuthorizedClient Authorized Client} is created by associating the
* {@link OAuth2AuthorizationCodeAuthenticationToken#getClientRegistration() client} to the
* {@link OAuth2AuthorizationCodeAuthenticationToken#getAccessToken() access token} and current {@code Principal}
* and saving it via the {@link OAuth2AuthorizedClientService}.
* and saving it via the {@link OAuth2AuthorizedClientRepository}.
* </li>
* </ul>
*
@ -88,13 +86,13 @@ import org.springframework.web.util.UriComponentsBuilder;
* @see OAuth2AuthorizationRequestRedirectFilter
* @see ClientRegistrationRepository
* @see OAuth2AuthorizedClient
* @see OAuth2AuthorizedClientService
* @see OAuth2AuthorizedClientRepository
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1">Section 4.1 Authorization Code Grant</a>
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.2">Section 4.1.2 Authorization Response</a>
*/
public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
private final ClientRegistrationRepository clientRegistrationRepository;
private final OAuth2AuthorizedClientService authorizedClientService;
private final OAuth2AuthorizedClientRepository authorizedClientRepository;
private final AuthenticationManager authenticationManager;
private AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
new HttpSessionOAuth2AuthorizationRequestRepository();
@ -106,17 +104,17 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
* Constructs an {@code OAuth2AuthorizationCodeGrantFilter} using the provided parameters.
*
* @param clientRegistrationRepository the repository of client registrations
* @param authorizedClientService the authorized client service
* @param authorizedClientRepository the authorized client repository
* @param authenticationManager the authentication manager
*/
public OAuth2AuthorizationCodeGrantFilter(ClientRegistrationRepository clientRegistrationRepository,
OAuth2AuthorizedClientService authorizedClientService,
OAuth2AuthorizedClientRepository authorizedClientRepository,
AuthenticationManager authenticationManager) {
Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
Assert.notNull(authorizedClientService, "authorizedClientService cannot be null");
Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
Assert.notNull(authenticationManager, "authenticationManager cannot be null");
this.clientRegistrationRepository = clientRegistrationRepository;
this.authorizedClientService = authorizedClientService;
this.authorizedClientRepository = authorizedClientRepository;
this.authenticationManager = authenticationManager;
}
@ -201,7 +199,7 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
authenticationResult.getAccessToken(),
authenticationResult.getRefreshToken());
this.authorizedClientService.saveAuthorizedClient(authorizedClient, currentAuthentication);
this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, currentAuthentication, request, response);
String redirectUrl = authorizationResponse.getRedirectUri();
SavedRequest savedRequest = this.requestCache.getRequest(request, response);

View File

@ -0,0 +1,84 @@
/*
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.web;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
/**
* Implementations of this interface are responsible for the persistence
* of {@link OAuth2AuthorizedClient Authorized Client(s)} between requests.
*
* <p>
* The primary purpose of an {@link OAuth2AuthorizedClient Authorized Client}
* is to associate an {@link OAuth2AuthorizedClient#getAccessToken() Access Token} credential
* to a {@link OAuth2AuthorizedClient#getClientRegistration() Client} and Resource Owner,
* who is the {@link OAuth2AuthorizedClient#getPrincipalName() Principal}
* that originally granted the authorization.
*
* @author Joe Grandja
* @since 5.1
* @see OAuth2AuthorizedClient
* @see ClientRegistration
* @see Authentication
* @see OAuth2AccessToken
*/
public interface OAuth2AuthorizedClientRepository {
/**
* Returns the {@link OAuth2AuthorizedClient} associated to the
* provided client registration identifier and End-User {@link Authentication} (Resource Owner)
* or {@code null} if not available.
*
* @param clientRegistrationId the identifier for the client's registration
* @param principal the End-User {@link Authentication} (Resource Owner)
* @param request the {@code HttpServletRequest}
* @param <T> a type of OAuth2AuthorizedClient
* @return the {@link OAuth2AuthorizedClient} or {@code null} if not available
*/
<T extends OAuth2AuthorizedClient> T loadAuthorizedClient(String clientRegistrationId, Authentication principal,
HttpServletRequest request);
/**
* Saves the {@link OAuth2AuthorizedClient} associating it to
* the provided End-User {@link Authentication} (Resource Owner).
*
* @param authorizedClient the authorized client
* @param principal the End-User {@link Authentication} (Resource Owner)
* @param request the {@code HttpServletRequest}
* @param response the {@code HttpServletResponse}
*/
void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal,
HttpServletRequest request, HttpServletResponse response);
/**
* Removes the {@link OAuth2AuthorizedClient} associated to the
* provided client registration identifier and End-User {@link Authentication} (Resource Owner).
*
* @param clientRegistrationId the identifier for the client's registration
* @param principal the End-User {@link Authentication} (Resource Owner)
* @param request the {@code HttpServletRequest}
* @param response the {@code HttpServletResponse}
*/
void removeAuthorizedClient(String clientRegistrationId, Authentication principal,
HttpServletRequest request, HttpServletResponse response);
}

View File

@ -23,9 +23,9 @@ import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.bind.support.WebDataBinderFactory;
@ -33,6 +33,8 @@ import org.springframework.web.context.request.NativeWebRequest;
import org.springframework.web.method.support.HandlerMethodArgumentResolver;
import org.springframework.web.method.support.ModelAndViewContainer;
import javax.servlet.http.HttpServletRequest;
/**
* An implementation of a {@link HandlerMethodArgumentResolver} that is capable
* of resolving a method parameter to an argument value of type {@link OAuth2AuthorizedClient}.
@ -54,16 +56,16 @@ import org.springframework.web.method.support.ModelAndViewContainer;
* @see RegisteredOAuth2AuthorizedClient
*/
public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMethodArgumentResolver {
private final OAuth2AuthorizedClientService authorizedClientService;
private final OAuth2AuthorizedClientRepository authorizedClientRepository;
/**
* Constructs an {@code OAuth2AuthorizedClientArgumentResolver} using the provided parameters.
*
* @param authorizedClientService the authorized client service
* @param authorizedClientRepository the authorized client repository
*/
public OAuth2AuthorizedClientArgumentResolver(OAuth2AuthorizedClientService authorizedClientService) {
Assert.notNull(authorizedClientService, "authorizedClientService cannot be null");
this.authorizedClientService = authorizedClientService;
public OAuth2AuthorizedClientArgumentResolver(OAuth2AuthorizedClientRepository authorizedClientRepository) {
Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
this.authorizedClientRepository = authorizedClientRepository;
}
@Override
@ -98,15 +100,8 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth
"It must be provided via @RegisteredOAuth2AuthorizedClient(\"client1\") or @RegisteredOAuth2AuthorizedClient(registrationId = \"client1\").");
}
if (principal == null) {
// An Authentication is required given that an OAuth2AuthorizedClient is associated to a Principal
throw new IllegalStateException("Unable to resolve the Authorized Client with registration identifier \"" +
clientRegistrationId + "\". An \"authenticated\" or \"unauthenticated\" session is required. " +
"To allow for unauthenticated access, ensure HttpSecurity.anonymous() is configured.");
}
OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient(
clientRegistrationId, principal.getName());
OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
clientRegistrationId, principal, webRequest.getNativeRequest(HttpServletRequest.class));
if (authorizedClient == null) {
throw new ClientAuthorizationRequiredException(clientRegistrationId);
}

View File

@ -0,0 +1,122 @@
/*
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.web;
import org.junit.Before;
import org.junit.Test;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
/**
* Tests for {@link AuthenticatedPrincipalOAuth2AuthorizedClientRepository}.
*
* @author Joe Grandja
*/
public class AuthenticatedPrincipalOAuth2AuthorizedClientRepositoryTests {
private String registrationId = "registrationId";
private String principalName = "principalName";
private OAuth2AuthorizedClientService authorizedClientService;
private OAuth2AuthorizedClientRepository anonymousAuthorizedClientRepository;
private AuthenticatedPrincipalOAuth2AuthorizedClientRepository authorizedClientRepository;
private MockHttpServletRequest request;
private MockHttpServletResponse response;
@Before
public void setup() {
this.authorizedClientService = mock(OAuth2AuthorizedClientService.class);
this.anonymousAuthorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class);
this.authorizedClientRepository = new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(this.authorizedClientService);
this.authorizedClientRepository.setAnonymousAuthorizedClientRepository(this.anonymousAuthorizedClientRepository);
this.request = new MockHttpServletRequest();
this.response = new MockHttpServletResponse();
}
@Test
public void constructorWhenAuthorizedClientServiceIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(null))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void setAuthorizedClientRepositoryWhenAuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authorizedClientRepository.setAnonymousAuthorizedClientRepository(null))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void loadAuthorizedClientWhenAuthenticatedPrincipalThenLoadFromService() {
Authentication authentication = this.createAuthenticatedPrincipal();
this.authorizedClientRepository.loadAuthorizedClient(this.registrationId, authentication, this.request);
verify(this.authorizedClientService).loadAuthorizedClient(this.registrationId, this.principalName);
}
@Test
public void loadAuthorizedClientWhenAnonymousPrincipalThenLoadFromAnonymousRepository() {
Authentication authentication = this.createAnonymousPrincipal();
this.authorizedClientRepository.loadAuthorizedClient(this.registrationId, authentication, this.request);
verify(this.anonymousAuthorizedClientRepository).loadAuthorizedClient(this.registrationId, authentication, this.request);
}
@Test
public void saveAuthorizedClientWhenAuthenticatedPrincipalThenSaveToService() {
Authentication authentication = this.createAuthenticatedPrincipal();
OAuth2AuthorizedClient authorizedClient = mock(OAuth2AuthorizedClient.class);
this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, authentication, this.request, this.response);
verify(this.authorizedClientService).saveAuthorizedClient(authorizedClient, authentication);
}
@Test
public void saveAuthorizedClientWhenAnonymousPrincipalThenSaveToAnonymousRepository() {
Authentication authentication = this.createAnonymousPrincipal();
OAuth2AuthorizedClient authorizedClient = mock(OAuth2AuthorizedClient.class);
this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, authentication, this.request, this.response);
verify(this.anonymousAuthorizedClientRepository).saveAuthorizedClient(authorizedClient, authentication, this.request, this.response);
}
@Test
public void removeAuthorizedClientWhenAuthenticatedPrincipalThenRemoveFromService() {
Authentication authentication = this.createAuthenticatedPrincipal();
this.authorizedClientRepository.removeAuthorizedClient(this.registrationId, authentication, this.request, this.response);
verify(this.authorizedClientService).removeAuthorizedClient(this.registrationId, this.principalName);
}
@Test
public void removeAuthorizedClientWhenAnonymousPrincipalThenRemoveFromAnonymousRepository() {
Authentication authentication = this.createAnonymousPrincipal();
this.authorizedClientRepository.removeAuthorizedClient(this.registrationId, authentication, this.request, this.response);
verify(this.anonymousAuthorizedClientRepository).removeAuthorizedClient(this.registrationId, authentication, this.request, this.response);
}
private Authentication createAuthenticatedPrincipal() {
TestingAuthenticationToken authentication = new TestingAuthenticationToken(this.principalName, "password");
authentication.setAuthenticated(true);
return authentication;
}
private Authentication createAnonymousPrincipal() {
return new AnonymousAuthenticationToken("key-1234", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"));
}
}

View File

@ -0,0 +1,261 @@
/*
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.web;
import org.junit.Before;
import org.junit.Test;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import javax.servlet.http.HttpSession;
import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.Mockito.mock;
/**
* Tests for {@link HttpSessionOAuth2AuthorizedClientRepository}.
*
* @author Joe Grandja
*/
public class HttpSessionOAuth2AuthorizedClientRepositoryTests {
private String registrationId1 = "registration-1";
private String registrationId2 = "registration-2";
private String principalName1 = "principalName-1";
private ClientRegistration registration1 = ClientRegistration.withRegistrationId(this.registrationId1)
.clientId("client-1")
.clientSecret("secret")
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUriTemplate("{baseUrl}/login/oauth2/code/{registrationId}")
.scope("user")
.authorizationUri("https://provider.com/oauth2/authorize")
.tokenUri("https://provider.com/oauth2/token")
.userInfoUri("https://provider.com/oauth2/user")
.userNameAttributeName("id")
.clientName("client-1")
.build();
private ClientRegistration registration2 = ClientRegistration.withRegistrationId(this.registrationId2)
.clientId("client-2")
.clientSecret("secret")
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUriTemplate("{baseUrl}/login/oauth2/code/{registrationId}")
.scope("openid", "profile", "email")
.authorizationUri("https://provider.com/oauth2/authorize")
.tokenUri("https://provider.com/oauth2/token")
.userInfoUri("https://provider.com/oauth2/userinfo")
.jwkSetUri("https://provider.com/oauth2/keys")
.clientName("client-2")
.build();
private HttpSessionOAuth2AuthorizedClientRepository authorizedClientRepository =
new HttpSessionOAuth2AuthorizedClientRepository();
private MockHttpServletRequest request;
private MockHttpServletResponse response;
@Before
public void setup() {
this.request = new MockHttpServletRequest();
this.response = new MockHttpServletResponse();
}
@Test
public void loadAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authorizedClientRepository.loadAuthorizedClient(null, null, this.request))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void loadAuthorizedClientWhenPrincipalNameIsNullThenExceptionNotThrown() {
this.authorizedClientRepository.loadAuthorizedClient(this.registrationId1, null, this.request);
}
@Test
public void loadAuthorizedClientWhenRequestIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authorizedClientRepository.loadAuthorizedClient(this.registrationId1, null, null))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void loadAuthorizedClientWhenClientRegistrationNotFoundThenReturnNull() {
OAuth2AuthorizedClient authorizedClient =
this.authorizedClientRepository.loadAuthorizedClient("registration-not-found", null, this.request);
assertThat(authorizedClient).isNull();
}
@Test
public void loadAuthorizedClientWhenSavedThenReturnAuthorizedClient() {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
this.registration1, this.principalName1, mock(OAuth2AccessToken.class));
this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, this.request, this.response);
OAuth2AuthorizedClient loadedAuthorizedClient =
this.authorizedClientRepository.loadAuthorizedClient(this.registrationId1, null, this.request);
assertThat(loadedAuthorizedClient).isEqualTo(authorizedClient);
}
@Test
public void saveAuthorizedClientWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authorizedClientRepository.saveAuthorizedClient(null, null, this.request, this.response))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void saveAuthorizedClientWhenAuthenticationIsNullThenExceptionNotThrown() {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
this.registration2, this.principalName1, mock(OAuth2AccessToken.class));
this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, this.request, this.response);
}
@Test
public void saveAuthorizedClientWhenRequestIsNullThenThrowIllegalArgumentException() {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
this.registration2, this.principalName1, mock(OAuth2AccessToken.class));
assertThatThrownBy(() -> this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, null, this.response))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void saveAuthorizedClientWhenResponseIsNullThenThrowIllegalArgumentException() {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
this.registration2, this.principalName1, mock(OAuth2AccessToken.class));
assertThatThrownBy(() -> this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, this.request, null))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void saveAuthorizedClientWhenSavedThenSavedToSession() {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
this.registration2, this.principalName1, mock(OAuth2AccessToken.class));
this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, this.request, this.response);
HttpSession session = this.request.getSession(false);
assertThat(session).isNotNull();
@SuppressWarnings("unchecked")
Map<String, OAuth2AuthorizedClient> authorizedClients = (Map<String, OAuth2AuthorizedClient>)
session.getAttribute(HttpSessionOAuth2AuthorizedClientRepository.class.getName() + ".AUTHORIZED_CLIENTS");
assertThat(authorizedClients).isNotEmpty();
assertThat(authorizedClients).hasSize(1);
assertThat(authorizedClients.values().iterator().next()).isSameAs(authorizedClient);
}
@Test
public void removeAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authorizedClientRepository.removeAuthorizedClient(
null, null, this.request, this.response)).isInstanceOf(IllegalArgumentException.class);
}
@Test
public void removeAuthorizedClientWhenPrincipalNameIsNullThenExceptionNotThrown() {
this.authorizedClientRepository.removeAuthorizedClient(this.registrationId1, null, this.request, this.response);
}
@Test
public void removeAuthorizedClientWhenRequestIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authorizedClientRepository.removeAuthorizedClient(
this.registrationId1, null, null, this.response)).isInstanceOf(IllegalArgumentException.class);
}
@Test
public void removeAuthorizedClientWhenResponseIsNullThenExceptionNotThrown() {
this.authorizedClientRepository.removeAuthorizedClient(this.registrationId1, null, this.request, null);
}
@Test
public void removeAuthorizedClientWhenNotSavedThenSessionNotCreated() {
this.authorizedClientRepository.removeAuthorizedClient(
this.registrationId2, null, this.request, this.response);
assertThat(this.request.getSession(false)).isNull();
}
@Test
public void removeAuthorizedClientWhenClient1SavedAndClient2RemovedThenClient1NotRemoved() {
OAuth2AuthorizedClient authorizedClient1 = new OAuth2AuthorizedClient(
this.registration1, this.principalName1, mock(OAuth2AccessToken.class));
this.authorizedClientRepository.saveAuthorizedClient(authorizedClient1, null, this.request, this.response);
// Remove registrationId2 (never added so is not removed either)
this.authorizedClientRepository.removeAuthorizedClient(
this.registrationId2, null, this.request, this.response);
OAuth2AuthorizedClient loadedAuthorizedClient1 = this.authorizedClientRepository.loadAuthorizedClient(
this.registrationId1, null, this.request);
assertThat(loadedAuthorizedClient1).isNotNull();
assertThat(loadedAuthorizedClient1).isSameAs(authorizedClient1);
}
@Test
public void removeAuthorizedClientWhenSavedThenRemoved() {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
this.registration2, this.principalName1, mock(OAuth2AccessToken.class));
this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, this.request, this.response);
OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
this.registrationId2, null, this.request);
assertThat(loadedAuthorizedClient).isSameAs(authorizedClient);
this.authorizedClientRepository.removeAuthorizedClient(
this.registrationId2, null, this.request, this.response);
loadedAuthorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
this.registrationId2, null, this.request);
assertThat(loadedAuthorizedClient).isNull();
}
@Test
public void removeAuthorizedClientWhenSavedThenRemovedFromSession() {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
this.registration1, this.principalName1, mock(OAuth2AccessToken.class));
this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, this.request, this.response);
OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
this.registrationId1, null, this.request);
assertThat(loadedAuthorizedClient).isSameAs(authorizedClient);
this.authorizedClientRepository.removeAuthorizedClient(
this.registrationId1, null, this.request, this.response);
HttpSession session = this.request.getSession(false);
assertThat(session).isNotNull();
assertThat(session.getAttribute(HttpSessionOAuth2AuthorizedClientRepository.class.getName() + ".AUTHORIZED_CLIENTS")).isNull();
}
@Test
public void removeAuthorizedClientWhenClient1Client2SavedAndClient1RemovedThenClient2NotRemoved() {
OAuth2AuthorizedClient authorizedClient1 = new OAuth2AuthorizedClient(
this.registration1, this.principalName1, mock(OAuth2AccessToken.class));
this.authorizedClientRepository.saveAuthorizedClient(authorizedClient1, null, this.request, this.response);
OAuth2AuthorizedClient authorizedClient2 = new OAuth2AuthorizedClient(
this.registration2, this.principalName1, mock(OAuth2AccessToken.class));
this.authorizedClientRepository.saveAuthorizedClient(authorizedClient2, null, this.request, this.response);
this.authorizedClientRepository.removeAuthorizedClient(
this.registrationId1, null, this.request, this.response);
OAuth2AuthorizedClient loadedAuthorizedClient2 = this.authorizedClientRepository.loadAuthorizedClient(
this.registrationId2, null, this.request);
assertThat(loadedAuthorizedClient2).isNotNull();
assertThat(loadedAuthorizedClient2).isSameAs(authorizedClient2);
}
}

View File

@ -15,6 +15,7 @@
*/
package org.springframework.security.oauth2.client.web;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
@ -23,9 +24,11 @@ import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService;
@ -51,6 +54,7 @@ import org.springframework.security.web.savedrequest.RequestCache;
import javax.servlet.FilterChain;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import java.util.HashMap;
import java.util.Map;
@ -71,12 +75,13 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
private String principalName1 = "principal-1";
private ClientRegistrationRepository clientRegistrationRepository;
private OAuth2AuthorizedClientService authorizedClientService;
private OAuth2AuthorizedClientRepository authorizedClientRepository;
private AuthenticationManager authenticationManager;
private AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository;
private OAuth2AuthorizationCodeGrantFilter filter;
@Before
public void setUp() {
public void setup() {
this.registration1 = ClientRegistration.withRegistrationId("registration-1")
.clientId("client-1")
.clientSecret("secret")
@ -92,32 +97,39 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
.build();
this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(this.registration1);
this.authorizedClientService = new InMemoryOAuth2AuthorizedClientService(this.clientRegistrationRepository);
this.authorizedClientRepository = new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(this.authorizedClientService);
this.authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository();
this.authenticationManager = mock(AuthenticationManager.class);
this.filter = spy(new OAuth2AuthorizationCodeGrantFilter(
this.clientRegistrationRepository, this.authorizedClientService, this.authenticationManager));
this.clientRegistrationRepository, this.authorizedClientRepository, this.authenticationManager));
this.filter.setAuthorizationRequestRepository(this.authorizationRequestRepository);
TestingAuthenticationToken authentication = new TestingAuthenticationToken(this.principalName1, "password");
authentication.setAuthenticated(true);
SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
securityContext.setAuthentication(new TestingAuthenticationToken(this.principalName1, "password"));
securityContext.setAuthentication(authentication);
SecurityContextHolder.setContext(securityContext);
}
@After
public void cleanup() {
SecurityContextHolder.clearContext();
}
@Test
public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2AuthorizationCodeGrantFilter(null, this.authorizedClientService, this.authenticationManager))
assertThatThrownBy(() -> new OAuth2AuthorizationCodeGrantFilter(null, this.authorizedClientRepository, this.authenticationManager))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void constructorWhenAuthorizedClientServiceIsNullThenThrowIllegalArgumentException() {
public void constructorWhenAuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2AuthorizationCodeGrantFilter(this.clientRegistrationRepository, null, this.authenticationManager))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void constructorWhenAuthenticationManagerIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2AuthorizationCodeGrantFilter(this.clientRegistrationRepository, this.authorizedClientService, null))
assertThatThrownBy(() -> new OAuth2AuthorizationCodeGrantFilter(this.clientRegistrationRepository, this.authorizedClientRepository, null))
.isInstanceOf(IllegalArgumentException.class);
}
@ -218,7 +230,7 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
}
@Test
public void doFilterWhenAuthorizationResponseSuccessThenAuthorizedClientSaved() throws Exception {
public void doFilterWhenAuthorizationResponseSuccessThenAuthorizedClientSavedToService() throws Exception {
String requestUri = "/callback/client-1";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
@ -285,6 +297,47 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/saved-request");
}
@Test
public void doFilterWhenAuthorizationResponseSuccessAndAnonymousAccessThenAuthorizedClientSavedToHttpSession() throws Exception {
AnonymousAuthenticationToken anonymousPrincipal =
new AnonymousAuthenticationToken("key-1234", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"));
SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
securityContext.setAuthentication(anonymousPrincipal);
SecurityContextHolder.setContext(securityContext);
String requestUri = "/callback/client-1";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
request.addParameter(OAuth2ParameterNames.CODE, "code");
request.addParameter(OAuth2ParameterNames.STATE, "state");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.setUpAuthorizationRequest(request, response, this.registration1);
this.setUpAuthenticationResult(this.registration1);
this.filter.doFilter(request, response, filterChain);
OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
this.registration1.getRegistrationId(), anonymousPrincipal, request);
assertThat(authorizedClient).isNotNull();
assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration1);
assertThat(authorizedClient.getPrincipalName()).isEqualTo(anonymousPrincipal.getName());
assertThat(authorizedClient.getAccessToken()).isNotNull();
HttpSession session = request.getSession(false);
assertThat(session).isNotNull();
@SuppressWarnings("unchecked")
Map<String, OAuth2AuthorizedClient> authorizedClients = (Map<String, OAuth2AuthorizedClient>)
session.getAttribute(HttpSessionOAuth2AuthorizedClientRepository.class.getName() + ".AUTHORIZED_CLIENTS");
assertThat(authorizedClients).isNotEmpty();
assertThat(authorizedClients).hasSize(1);
assertThat(authorizedClients.values().iterator().next()).isSameAs(authorizedClient);
}
private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response,
ClientRegistration registration) {
Map<String, Object> additionalParameters = new HashMap<>();

View File

@ -15,19 +15,23 @@
*/
package org.springframework.security.oauth2.client.web.method.annotation;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.springframework.core.MethodParameter;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.util.ReflectionUtils;
import org.springframework.web.context.request.ServletWebRequest;
import javax.servlet.http.HttpServletRequest;
import java.lang.reflect.Method;
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
@ -43,21 +47,29 @@ import static org.mockito.Mockito.when;
* @author Joe Grandja
*/
public class OAuth2AuthorizedClientArgumentResolverTests {
private OAuth2AuthorizedClientService authorizedClientService;
private OAuth2AuthorizedClientRepository authorizedClientRepository;
private OAuth2AuthorizedClientArgumentResolver argumentResolver;
private OAuth2AuthorizedClient authorizedClient;
private MockHttpServletRequest request;
@Before
public void setUp() {
this.authorizedClientService = mock(OAuth2AuthorizedClientService.class);
this.argumentResolver = new OAuth2AuthorizedClientArgumentResolver(this.authorizedClientService);
public void setup() {
this.authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class);
this.argumentResolver = new OAuth2AuthorizedClientArgumentResolver(this.authorizedClientRepository);
this.authorizedClient = mock(OAuth2AuthorizedClient.class);
when(this.authorizedClientService.loadAuthorizedClient(anyString(), any())).thenReturn(this.authorizedClient);
this.request = new MockHttpServletRequest();
when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any(HttpServletRequest.class)))
.thenReturn(this.authorizedClient);
SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
securityContext.setAuthentication(mock(Authentication.class));
SecurityContextHolder.setContext(securityContext);
}
@After
public void cleanup() {
SecurityContextHolder.clearContext();
}
@Test
public void constructorWhenOAuth2AuthorizedClientServiceIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(null))
@ -104,31 +116,22 @@ public class OAuth2AuthorizedClientArgumentResolverTests {
securityContext.setAuthentication(authentication);
SecurityContextHolder.setContext(securityContext);
MethodParameter methodParameter = this.getMethodParameter("registrationIdEmpty", OAuth2AuthorizedClient.class);
this.argumentResolver.resolveArgument(methodParameter, null, null, null);
}
@Test
public void resolveArgumentWhenParameterTypeOAuth2AuthorizedClientAndCurrentAuthenticationNullThenThrowIllegalStateException() throws Exception {
SecurityContextHolder.clearContext();
MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class);
assertThatThrownBy(() -> this.argumentResolver.resolveArgument(methodParameter, null, null, null))
.isInstanceOf(IllegalStateException.class)
.hasMessage("Unable to resolve the Authorized Client with registration identifier \"client1\". " +
"An \"authenticated\" or \"unauthenticated\" session is required. " +
"To allow for unauthenticated access, ensure HttpSecurity.anonymous() is configured.");
this.argumentResolver.resolveArgument(methodParameter, null, new ServletWebRequest(this.request), null);
}
@Test
public void resolveArgumentWhenOAuth2AuthorizedClientFoundThenResolves() throws Exception {
MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class);
assertThat(this.argumentResolver.resolveArgument(methodParameter, null, null, null)).isSameAs(this.authorizedClient);
assertThat(this.argumentResolver.resolveArgument(
methodParameter, null, new ServletWebRequest(this.request), null)).isSameAs(this.authorizedClient);
}
@Test
public void resolveArgumentWhenOAuth2AuthorizedClientNotFoundThenThrowClientAuthorizationRequiredException() throws Exception {
when(this.authorizedClientService.loadAuthorizedClient(anyString(), any())).thenReturn(null);
when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any(HttpServletRequest.class)))
.thenReturn(null);
MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class);
assertThatThrownBy(() -> this.argumentResolver.resolveArgument(methodParameter, null, null, null))
assertThatThrownBy(() -> this.argumentResolver.resolveArgument(methodParameter, null, new ServletWebRequest(this.request), null))
.isInstanceOf(ClientAuthorizationRequiredException.class);
}