Use OAuth2AuthorizedClientRepository in filters and resolver

Fixes gh-5544
This commit is contained in:
Joe Grandja 2018-07-19 16:46:13 -04:00
parent 39e336136f
commit 9a144d742e
12 changed files with 215 additions and 87 deletions

View File

@ -20,8 +20,7 @@ import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import; import org.springframework.context.annotation.Import;
import org.springframework.context.annotation.ImportSelector; import org.springframework.context.annotation.ImportSelector;
import org.springframework.core.type.AnnotationMetadata; import org.springframework.core.type.AnnotationMetadata;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.method.annotation.OAuth2AuthorizedClientArgumentResolver; import org.springframework.security.oauth2.client.web.method.annotation.OAuth2AuthorizedClientArgumentResolver;
import org.springframework.util.ClassUtils; import org.springframework.util.ClassUtils;
import org.springframework.web.method.support.HandlerMethodArgumentResolver; import org.springframework.web.method.support.HandlerMethodArgumentResolver;
@ -58,22 +57,21 @@ final class OAuth2ClientConfiguration {
@Configuration @Configuration
static class OAuth2ClientWebMvcSecurityConfiguration implements WebMvcConfigurer { static class OAuth2ClientWebMvcSecurityConfiguration implements WebMvcConfigurer {
private OAuth2AuthorizedClientService authorizedClientService; private OAuth2AuthorizedClientRepository authorizedClientRepository;
@Override @Override
public void addArgumentResolvers(List<HandlerMethodArgumentResolver> argumentResolvers) { public void addArgumentResolvers(List<HandlerMethodArgumentResolver> argumentResolvers) {
if (this.authorizedClientService != null) { if (this.authorizedClientRepository != null) {
OAuth2AuthorizedClientArgumentResolver authorizedClientArgumentResolver = OAuth2AuthorizedClientArgumentResolver authorizedClientArgumentResolver =
new OAuth2AuthorizedClientArgumentResolver( new OAuth2AuthorizedClientArgumentResolver(this.authorizedClientRepository);
new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(this.authorizedClientService));
argumentResolvers.add(authorizedClientArgumentResolver); argumentResolvers.add(authorizedClientArgumentResolver);
} }
} }
@Autowired(required = false) @Autowired(required = false)
public void setAuthorizedClientService(List<OAuth2AuthorizedClientService> authorizedClientServices) { public void setAuthorizedClientRepository(List<OAuth2AuthorizedClientRepository> authorizedClientRepositories) {
if (authorizedClientServices.size() == 1) { if (authorizedClientRepositories.size() == 1) {
this.authorizedClientService = authorizedClientServices.get(0); this.authorizedClientRepository = authorizedClientRepositories.get(0);
} }
} }
} }

View File

@ -29,6 +29,7 @@ import org.springframework.security.oauth2.client.web.AuthorizationRequestReposi
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationCodeGrantFilter; import org.springframework.security.oauth2.client.web.OAuth2AuthorizationCodeGrantFilter;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter; import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver; import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.web.savedrequest.RequestCache; import org.springframework.security.web.savedrequest.RequestCache;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@ -63,7 +64,7 @@ import org.springframework.util.Assert;
* *
* <ul> * <ul>
* <li>{@link ClientRegistrationRepository} (required)</li> * <li>{@link ClientRegistrationRepository} (required)</li>
* <li>{@link OAuth2AuthorizedClientService} (optional)</li> * <li>{@link OAuth2AuthorizedClientRepository} (optional)</li>
* </ul> * </ul>
* *
* <h2>Shared Objects Used</h2> * <h2>Shared Objects Used</h2>
@ -72,7 +73,7 @@ import org.springframework.util.Assert;
* *
* <ul> * <ul>
* <li>{@link ClientRegistrationRepository}</li> * <li>{@link ClientRegistrationRepository}</li>
* <li>{@link OAuth2AuthorizedClientService}</li> * <li>{@link OAuth2AuthorizedClientRepository}</li>
* </ul> * </ul>
* *
* @author Joe Grandja * @author Joe Grandja
@ -80,7 +81,7 @@ import org.springframework.util.Assert;
* @see OAuth2AuthorizationRequestRedirectFilter * @see OAuth2AuthorizationRequestRedirectFilter
* @see OAuth2AuthorizationCodeGrantFilter * @see OAuth2AuthorizationCodeGrantFilter
* @see ClientRegistrationRepository * @see ClientRegistrationRepository
* @see OAuth2AuthorizedClientService * @see OAuth2AuthorizedClientRepository
* @see AbstractHttpConfigurer * @see AbstractHttpConfigurer
*/ */
public final class OAuth2ClientConfigurer<B extends HttpSecurityBuilder<B>> extends public final class OAuth2ClientConfigurer<B extends HttpSecurityBuilder<B>> extends
@ -100,6 +101,18 @@ public final class OAuth2ClientConfigurer<B extends HttpSecurityBuilder<B>> exte
return this; return this;
} }
/**
* Sets the repository for authorized client(s).
*
* @param authorizedClientRepository the authorized client repository
* @return the {@link OAuth2ClientConfigurer} for further configuration
*/
public OAuth2ClientConfigurer<B> authorizedClientRepository(OAuth2AuthorizedClientRepository authorizedClientRepository) {
Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
this.getBuilder().setSharedObject(OAuth2AuthorizedClientRepository.class, authorizedClientRepository);
return this;
}
/** /**
* Sets the service for authorized client(s). * Sets the service for authorized client(s).
* *
@ -108,7 +121,7 @@ public final class OAuth2ClientConfigurer<B extends HttpSecurityBuilder<B>> exte
*/ */
public OAuth2ClientConfigurer<B> authorizedClientService(OAuth2AuthorizedClientService authorizedClientService) { public OAuth2ClientConfigurer<B> authorizedClientService(OAuth2AuthorizedClientService authorizedClientService) {
Assert.notNull(authorizedClientService, "authorizedClientService cannot be null"); Assert.notNull(authorizedClientService, "authorizedClientService cannot be null");
this.getBuilder().setSharedObject(OAuth2AuthorizedClientService.class, authorizedClientService); this.authorizedClientRepository(new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(authorizedClientService));
return this; return this;
} }
@ -309,8 +322,7 @@ public final class OAuth2ClientConfigurer<B extends HttpSecurityBuilder<B>> exte
OAuth2AuthorizationCodeGrantFilter authorizationCodeGrantFilter = new OAuth2AuthorizationCodeGrantFilter( OAuth2AuthorizationCodeGrantFilter authorizationCodeGrantFilter = new OAuth2AuthorizationCodeGrantFilter(
OAuth2ClientConfigurerUtils.getClientRegistrationRepository(builder), OAuth2ClientConfigurerUtils.getClientRegistrationRepository(builder),
new AuthenticatedPrincipalOAuth2AuthorizedClientRepository( OAuth2ClientConfigurerUtils.getAuthorizedClientRepository(builder),
OAuth2ClientConfigurerUtils.getAuthorizedClientService(builder)),
authenticationManager); authenticationManager);
if (authorizationCodeGrantConfigurer.authorizationEndpointConfig.authorizationRequestRepository != null) { if (authorizationCodeGrantConfigurer.authorizationEndpointConfig.authorizationRequestRepository != null) {

View File

@ -24,6 +24,8 @@ import org.springframework.security.config.annotation.web.configurers.AbstractHt
import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService; import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import java.util.Map; import java.util.Map;
@ -61,14 +63,35 @@ final class OAuth2ClientConfigurerUtils {
return clientRegistrationRepositoryMap.values().iterator().next(); return clientRegistrationRepositoryMap.values().iterator().next();
} }
static <B extends HttpSecurityBuilder<B>> OAuth2AuthorizedClientService getAuthorizedClientService(B builder) { static <B extends HttpSecurityBuilder<B>> OAuth2AuthorizedClientRepository getAuthorizedClientRepository(B builder) {
OAuth2AuthorizedClientService authorizedClientService = builder.getSharedObject(OAuth2AuthorizedClientService.class); OAuth2AuthorizedClientRepository authorizedClientRepository = builder.getSharedObject(OAuth2AuthorizedClientRepository.class);
if (authorizedClientService == null) { if (authorizedClientRepository == null) {
authorizedClientService = getAuthorizedClientServiceBean(builder); authorizedClientRepository = getAuthorizedClientRepositoryBean(builder);
if (authorizedClientService == null) { if (authorizedClientRepository == null) {
authorizedClientService = new InMemoryOAuth2AuthorizedClientService(getClientRegistrationRepository(builder)); authorizedClientRepository = new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(
getAuthorizedClientService((builder)));
} }
builder.setSharedObject(OAuth2AuthorizedClientService.class, authorizedClientService); builder.setSharedObject(OAuth2AuthorizedClientRepository.class, authorizedClientRepository);
}
return authorizedClientRepository;
}
private static <B extends HttpSecurityBuilder<B>> OAuth2AuthorizedClientRepository getAuthorizedClientRepositoryBean(B builder) {
Map<String, OAuth2AuthorizedClientRepository> authorizedClientRepositoryMap = BeanFactoryUtils.beansOfTypeIncludingAncestors(
builder.getSharedObject(ApplicationContext.class), OAuth2AuthorizedClientRepository.class);
if (authorizedClientRepositoryMap.size() > 1) {
throw new NoUniqueBeanDefinitionException(OAuth2AuthorizedClientRepository.class, authorizedClientRepositoryMap.size(),
"Expected single matching bean of type '" + OAuth2AuthorizedClientRepository.class.getName() + "' but found " +
authorizedClientRepositoryMap.size() + ": " + StringUtils.collectionToCommaDelimitedString(authorizedClientRepositoryMap.keySet()));
}
return (!authorizedClientRepositoryMap.isEmpty() ? authorizedClientRepositoryMap.values().iterator().next() : null);
}
private static <B extends HttpSecurityBuilder<B>> OAuth2AuthorizedClientService getAuthorizedClientService(B builder) {
OAuth2AuthorizedClientService authorizedClientService = getAuthorizedClientServiceBean(builder);
if (authorizedClientService == null) {
authorizedClientService = new InMemoryOAuth2AuthorizedClientService(getClientRegistrationRepository(builder));
} }
return authorizedClientService; return authorizedClientService;
} }

View File

@ -42,9 +42,11 @@ import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserServ
import org.springframework.security.oauth2.client.userinfo.DelegatingOAuth2UserService; import org.springframework.security.oauth2.client.userinfo.DelegatingOAuth2UserService;
import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest; import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
import org.springframework.security.oauth2.client.userinfo.OAuth2UserService; import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository; import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter; import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver; import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.OAuth2LoginAuthenticationFilter; import org.springframework.security.oauth2.client.web.OAuth2LoginAuthenticationFilter;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2Error;
@ -92,7 +94,7 @@ import java.util.Map;
* *
* <ul> * <ul>
* <li>{@link ClientRegistrationRepository} (required)</li> * <li>{@link ClientRegistrationRepository} (required)</li>
* <li>{@link OAuth2AuthorizedClientService} (optional)</li> * <li>{@link OAuth2AuthorizedClientRepository} (optional)</li>
* <li>{@link GrantedAuthoritiesMapper} (optional)</li> * <li>{@link GrantedAuthoritiesMapper} (optional)</li>
* </ul> * </ul>
* *
@ -102,7 +104,7 @@ import java.util.Map;
* *
* <ul> * <ul>
* <li>{@link ClientRegistrationRepository}</li> * <li>{@link ClientRegistrationRepository}</li>
* <li>{@link OAuth2AuthorizedClientService}</li> * <li>{@link OAuth2AuthorizedClientRepository}</li>
* <li>{@link GrantedAuthoritiesMapper}</li> * <li>{@link GrantedAuthoritiesMapper}</li>
* <li>{@link DefaultLoginPageGeneratingFilter} - if {@link #loginPage(String)} is not configured * <li>{@link DefaultLoginPageGeneratingFilter} - if {@link #loginPage(String)} is not configured
* and {@code DefaultLoginPageGeneratingFilter} is available, than a default login page will be made available</li> * and {@code DefaultLoginPageGeneratingFilter} is available, than a default login page will be made available</li>
@ -115,6 +117,7 @@ import java.util.Map;
* @see OAuth2AuthorizationRequestRedirectFilter * @see OAuth2AuthorizationRequestRedirectFilter
* @see OAuth2LoginAuthenticationFilter * @see OAuth2LoginAuthenticationFilter
* @see ClientRegistrationRepository * @see ClientRegistrationRepository
* @see OAuth2AuthorizedClientRepository
* @see AbstractAuthenticationFilterConfigurer * @see AbstractAuthenticationFilterConfigurer
*/ */
public final class OAuth2LoginConfigurer<B extends HttpSecurityBuilder<B>> extends public final class OAuth2LoginConfigurer<B extends HttpSecurityBuilder<B>> extends
@ -139,6 +142,19 @@ public final class OAuth2LoginConfigurer<B extends HttpSecurityBuilder<B>> exten
return this; return this;
} }
/**
* Sets the repository for authorized client(s).
*
* @since 5.1
* @param authorizedClientRepository the authorized client repository
* @return the {@link OAuth2LoginConfigurer} for further configuration
*/
public OAuth2LoginConfigurer<B> authorizedClientRepository(OAuth2AuthorizedClientRepository authorizedClientRepository) {
Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
this.getBuilder().setSharedObject(OAuth2AuthorizedClientRepository.class, authorizedClientRepository);
return this;
}
/** /**
* Sets the service for authorized client(s). * Sets the service for authorized client(s).
* *
@ -147,7 +163,7 @@ public final class OAuth2LoginConfigurer<B extends HttpSecurityBuilder<B>> exten
*/ */
public OAuth2LoginConfigurer<B> authorizedClientService(OAuth2AuthorizedClientService authorizedClientService) { public OAuth2LoginConfigurer<B> authorizedClientService(OAuth2AuthorizedClientService authorizedClientService) {
Assert.notNull(authorizedClientService, "authorizedClientService cannot be null"); Assert.notNull(authorizedClientService, "authorizedClientService cannot be null");
this.getBuilder().setSharedObject(OAuth2AuthorizedClientService.class, authorizedClientService); this.authorizedClientRepository(new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(authorizedClientService));
return this; return this;
} }
@ -400,7 +416,7 @@ public final class OAuth2LoginConfigurer<B extends HttpSecurityBuilder<B>> exten
OAuth2LoginAuthenticationFilter authenticationFilter = OAuth2LoginAuthenticationFilter authenticationFilter =
new OAuth2LoginAuthenticationFilter( new OAuth2LoginAuthenticationFilter(
OAuth2ClientConfigurerUtils.getClientRegistrationRepository(this.getBuilder()), OAuth2ClientConfigurerUtils.getClientRegistrationRepository(this.getBuilder()),
OAuth2ClientConfigurerUtils.getAuthorizedClientService(this.getBuilder()), OAuth2ClientConfigurerUtils.getAuthorizedClientRepository(this.getBuilder()),
this.loginProcessingUrl); this.loginProcessingUrl);
this.setAuthenticationFilter(authenticationFilter); this.setAuthenticationFilter(authenticationFilter);
super.loginProcessingUrl(this.loginProcessingUrl); super.loginProcessingUrl(this.loginProcessingUrl);

View File

@ -21,22 +21,27 @@ import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.beans.factory.NoUniqueBeanDefinitionException; import org.springframework.beans.factory.NoUniqueBeanDefinitionException;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.config.test.SpringTestRule;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient; import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MockMvc;
import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.config.annotation.EnableWebMvc; import org.springframework.web.servlet.config.annotation.EnableWebMvc;
import javax.servlet.http.HttpServletRequest;
import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
@ -57,18 +62,20 @@ public class OAuth2ClientConfigurationTests {
public void requestWhenAuthorizedClientFoundThenMethodArgumentResolved() throws Exception { public void requestWhenAuthorizedClientFoundThenMethodArgumentResolved() throws Exception {
String clientRegistrationId = "client1"; String clientRegistrationId = "client1";
String principalName = "user1"; String principalName = "user1";
TestingAuthenticationToken authentication = new TestingAuthenticationToken(principalName, "password");
OAuth2AuthorizedClientService authorizedClientService = mock(OAuth2AuthorizedClientService.class); OAuth2AuthorizedClientRepository authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class);
OAuth2AuthorizedClient authorizedClient = mock(OAuth2AuthorizedClient.class); OAuth2AuthorizedClient authorizedClient = mock(OAuth2AuthorizedClient.class);
when(authorizedClientService.loadAuthorizedClient(clientRegistrationId, principalName)).thenReturn(authorizedClient); when(authorizedClientRepository.loadAuthorizedClient(
eq(clientRegistrationId), eq(authentication), any(HttpServletRequest.class))).thenReturn(authorizedClient);
OAuth2AccessToken accessToken = mock(OAuth2AccessToken.class); OAuth2AccessToken accessToken = mock(OAuth2AccessToken.class);
when(authorizedClient.getAccessToken()).thenReturn(accessToken); when(authorizedClient.getAccessToken()).thenReturn(accessToken);
OAuth2AuthorizedClientArgumentResolverConfig.AUTHORIZED_CLIENT_SERVICE = authorizedClientService; OAuth2AuthorizedClientArgumentResolverConfig.AUTHORIZED_CLIENT_REPOSITORY = authorizedClientRepository;
this.spring.register(OAuth2AuthorizedClientArgumentResolverConfig.class).autowire(); this.spring.register(OAuth2AuthorizedClientArgumentResolverConfig.class).autowire();
this.mockMvc.perform(get("/authorized-client").with(user(principalName))) this.mockMvc.perform(get("/authorized-client").with(authentication(authentication)))
.andExpect(status().isOk()) .andExpect(status().isOk())
.andExpect(content().string("resolved")); .andExpect(content().string("resolved"));
} }
@ -76,7 +83,7 @@ public class OAuth2ClientConfigurationTests {
@EnableWebMvc @EnableWebMvc
@EnableWebSecurity @EnableWebSecurity
static class OAuth2AuthorizedClientArgumentResolverConfig extends WebSecurityConfigurerAdapter { static class OAuth2AuthorizedClientArgumentResolverConfig extends WebSecurityConfigurerAdapter {
static OAuth2AuthorizedClientService AUTHORIZED_CLIENT_SERVICE; static OAuth2AuthorizedClientRepository AUTHORIZED_CLIENT_REPOSITORY;
@Override @Override
protected void configure(HttpSecurity http) throws Exception { protected void configure(HttpSecurity http) throws Exception {
@ -92,23 +99,23 @@ public class OAuth2ClientConfigurationTests {
} }
@Bean @Bean
public OAuth2AuthorizedClientService authorizedClientService() { public OAuth2AuthorizedClientRepository authorizedClientRepository() {
return AUTHORIZED_CLIENT_SERVICE; return AUTHORIZED_CLIENT_REPOSITORY;
} }
} }
// gh-5321 // gh-5321
@Test @Test
public void loadContextWhenOAuth2AuthorizedClientServiceRegisteredTwiceThenThrowNoUniqueBeanDefinitionException() { public void loadContextWhenOAuth2AuthorizedClientRepositoryRegisteredTwiceThenThrowNoUniqueBeanDefinitionException() {
assertThatThrownBy(() -> this.spring.register(OAuth2AuthorizedClientServiceRegisteredTwiceConfig.class).autowire()) assertThatThrownBy(() -> this.spring.register(OAuth2AuthorizedClientRepositoryRegisteredTwiceConfig.class).autowire())
.hasRootCauseInstanceOf(NoUniqueBeanDefinitionException.class) .hasRootCauseInstanceOf(NoUniqueBeanDefinitionException.class)
.hasMessageContaining("Expected single matching bean of type '" + OAuth2AuthorizedClientService.class.getName() + .hasMessageContaining("Expected single matching bean of type '" + OAuth2AuthorizedClientRepository.class.getName() +
"' but found 2: authorizedClientService1,authorizedClientService2"); "' but found 2: authorizedClientRepository1,authorizedClientRepository2");
} }
@EnableWebMvc @EnableWebMvc
@EnableWebSecurity @EnableWebSecurity
static class OAuth2AuthorizedClientServiceRegisteredTwiceConfig extends WebSecurityConfigurerAdapter { static class OAuth2AuthorizedClientRepositoryRegisteredTwiceConfig extends WebSecurityConfigurerAdapter {
@Override @Override
protected void configure(HttpSecurity http) throws Exception { protected void configure(HttpSecurity http) throws Exception {
@ -127,13 +134,13 @@ public class OAuth2ClientConfigurationTests {
} }
@Bean @Bean
public OAuth2AuthorizedClientService authorizedClientService1() { public OAuth2AuthorizedClientRepository authorizedClientRepository1() {
return mock(OAuth2AuthorizedClientService.class); return mock(OAuth2AuthorizedClientRepository.class);
} }
@Bean @Bean
public OAuth2AuthorizedClientService authorizedClientService2() { public OAuth2AuthorizedClientRepository authorizedClientRepository2() {
return mock(OAuth2AuthorizedClientService.class); return mock(OAuth2AuthorizedClientRepository.class);
} }
} }
@ -194,8 +201,8 @@ public class OAuth2ClientConfigurationTests {
} }
@Bean @Bean
public OAuth2AuthorizedClientService authorizedClientService() { public OAuth2AuthorizedClientRepository authorizedClientRepository() {
return mock(OAuth2AuthorizedClientService.class); return mock(OAuth2AuthorizedClientRepository.class);
} }
} }
} }

View File

@ -23,6 +23,7 @@ import org.springframework.context.annotation.Bean;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.mock.web.MockHttpSession; import org.springframework.mock.web.MockHttpSession;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
@ -36,10 +37,12 @@ import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCo
import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository; import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizationRequestResolver; import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizationRequestResolver;
import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository; import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver; import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
@ -61,6 +64,7 @@ import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*; import static org.mockito.Mockito.*;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl;
@ -76,6 +80,8 @@ public class OAuth2ClientConfigurerTests {
private static OAuth2AuthorizedClientService authorizedClientService; private static OAuth2AuthorizedClientService authorizedClientService;
private static OAuth2AuthorizedClientRepository authorizedClientRepository;
private static OAuth2AuthorizationRequestResolver authorizationRequestResolver; private static OAuth2AuthorizationRequestResolver authorizationRequestResolver;
private static OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient; private static OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient;
@ -107,6 +113,7 @@ public class OAuth2ClientConfigurerTests {
.build(); .build();
clientRegistrationRepository = new InMemoryClientRegistrationRepository(this.registration1); clientRegistrationRepository = new InMemoryClientRegistrationRepository(this.registration1);
authorizedClientService = new InMemoryOAuth2AuthorizedClientService(clientRegistrationRepository); authorizedClientService = new InMemoryOAuth2AuthorizedClientService(clientRegistrationRepository);
authorizedClientRepository = new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(authorizedClientService);
authorizationRequestResolver = new DefaultOAuth2AuthorizationRequestResolver( authorizationRequestResolver = new DefaultOAuth2AuthorizationRequestResolver(
clientRegistrationRepository, "/oauth2/authorization"); clientRegistrationRepository, "/oauth2/authorization");
@ -153,17 +160,18 @@ public class OAuth2ClientConfigurerTests {
MockHttpSession session = (MockHttpSession) request.getSession(); MockHttpSession session = (MockHttpSession) request.getSession();
String principalName = "user1"; String principalName = "user1";
TestingAuthenticationToken authentication = new TestingAuthenticationToken(principalName, "password");
this.mockMvc.perform(get("/client-1") this.mockMvc.perform(get("/client-1")
.param(OAuth2ParameterNames.CODE, "code") .param(OAuth2ParameterNames.CODE, "code")
.param(OAuth2ParameterNames.STATE, "state") .param(OAuth2ParameterNames.STATE, "state")
.with(user(principalName)) .with(authentication(authentication))
.session(session)) .session(session))
.andExpect(status().is3xxRedirection()) .andExpect(status().is3xxRedirection())
.andExpect(redirectedUrl("http://localhost/client-1")); .andExpect(redirectedUrl("http://localhost/client-1"));
OAuth2AuthorizedClient authorizedClient = authorizedClientService.loadAuthorizedClient( OAuth2AuthorizedClient authorizedClient = authorizedClientRepository.loadAuthorizedClient(
this.registration1.getRegistrationId(), principalName); this.registration1.getRegistrationId(), authentication, request);
assertThat(authorizedClient).isNotNull(); assertThat(authorizedClient).isNotNull();
} }
@ -229,8 +237,8 @@ public class OAuth2ClientConfigurerTests {
} }
@Bean @Bean
public OAuth2AuthorizedClientService authorizedClientService() { public OAuth2AuthorizedClientRepository authorizedClientRepository() {
return authorizedClientService; return authorizedClientRepository;
} }
@RestController @RestController

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2017 the original author or authors. * Copyright 2002-2018 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -15,12 +15,6 @@
*/ */
package org.springframework.security.oauth2.client.web; package org.springframework.security.oauth2.client.web;
import java.io.IOException;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.AuthenticationException;
@ -43,6 +37,11 @@ import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.MultiValueMap; import org.springframework.util.MultiValueMap;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
/** /**
* An implementation of an {@link AbstractAuthenticationProcessingFilter} for OAuth 2.0 Login. * An implementation of an {@link AbstractAuthenticationProcessingFilter} for OAuth 2.0 Login.
* *
@ -68,7 +67,7 @@ import org.springframework.util.MultiValueMap;
* </li> * </li>
* <li> * <li>
* Upon a successful authentication, an {@link OAuth2AuthenticationToken} is created (representing the End-User {@code Principal}) * Upon a successful authentication, an {@link OAuth2AuthenticationToken} is created (representing the End-User {@code Principal})
* and associated to the {@link OAuth2AuthorizedClient Authorized Client} using the {@link OAuth2AuthorizedClientService}. * and associated to the {@link OAuth2AuthorizedClient Authorized Client} using the {@link OAuth2AuthorizedClientRepository}.
* </li> * </li>
* <li> * <li>
* Finally, the {@link OAuth2AuthenticationToken} is returned and ultimately stored * Finally, the {@link OAuth2AuthenticationToken} is returned and ultimately stored
@ -88,7 +87,7 @@ import org.springframework.util.MultiValueMap;
* @see OAuth2AuthorizationRequestRedirectFilter * @see OAuth2AuthorizationRequestRedirectFilter
* @see ClientRegistrationRepository * @see ClientRegistrationRepository
* @see OAuth2AuthorizedClient * @see OAuth2AuthorizedClient
* @see OAuth2AuthorizedClientService * @see OAuth2AuthorizedClientRepository
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1">Section 4.1 Authorization Code Grant</a> * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1">Section 4.1 Authorization Code Grant</a>
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.2">Section 4.1.2 Authorization Response</a> * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.2">Section 4.1.2 Authorization Response</a>
*/ */
@ -100,7 +99,7 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce
private static final String AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE = "authorization_request_not_found"; private static final String AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE = "authorization_request_not_found";
private static final String CLIENT_REGISTRATION_NOT_FOUND_ERROR_CODE = "client_registration_not_found"; private static final String CLIENT_REGISTRATION_NOT_FOUND_ERROR_CODE = "client_registration_not_found";
private ClientRegistrationRepository clientRegistrationRepository; private ClientRegistrationRepository clientRegistrationRepository;
private OAuth2AuthorizedClientService authorizedClientService; private OAuth2AuthorizedClientRepository authorizedClientRepository;
private AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = private AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
new HttpSessionOAuth2AuthorizationRequestRepository(); new HttpSessionOAuth2AuthorizationRequestRepository();
@ -125,11 +124,26 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce
public OAuth2LoginAuthenticationFilter(ClientRegistrationRepository clientRegistrationRepository, public OAuth2LoginAuthenticationFilter(ClientRegistrationRepository clientRegistrationRepository,
OAuth2AuthorizedClientService authorizedClientService, OAuth2AuthorizedClientService authorizedClientService,
String filterProcessesUrl) { String filterProcessesUrl) {
this(clientRegistrationRepository,
new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(authorizedClientService), filterProcessesUrl);
}
/**
* Constructs an {@code OAuth2LoginAuthenticationFilter} using the provided parameters.
*
* @since 5.1
* @param clientRegistrationRepository the repository of client registrations
* @param authorizedClientRepository the authorized client repository
* @param filterProcessesUrl the {@code URI} where this {@code Filter} will process the authentication requests
*/
public OAuth2LoginAuthenticationFilter(ClientRegistrationRepository clientRegistrationRepository,
OAuth2AuthorizedClientRepository authorizedClientRepository,
String filterProcessesUrl) {
super(filterProcessesUrl); super(filterProcessesUrl);
Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
Assert.notNull(authorizedClientService, "authorizedClientService cannot be null"); Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
this.clientRegistrationRepository = clientRegistrationRepository; this.clientRegistrationRepository = clientRegistrationRepository;
this.authorizedClientService = authorizedClientService; this.authorizedClientRepository = authorizedClientRepository;
} }
@Override @Override
@ -176,7 +190,7 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce
authenticationResult.getAccessToken(), authenticationResult.getAccessToken(),
authenticationResult.getRefreshToken()); authenticationResult.getRefreshToken());
this.authorizedClientService.saveAuthorizedClient(authorizedClient, oauth2Authentication); this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, oauth2Authentication, request, response);
return oauth2Authentication; return oauth2Authentication;
} }

View File

@ -70,10 +70,12 @@ public class OAuth2LoginAuthenticationFilterTests {
private ClientRegistration registration2; private ClientRegistration registration2;
private String principalName1 = "principal-1"; private String principalName1 = "principal-1";
private ClientRegistrationRepository clientRegistrationRepository; private ClientRegistrationRepository clientRegistrationRepository;
private OAuth2AuthorizedClientRepository authorizedClientRepository;
private OAuth2AuthorizedClientService authorizedClientService; private OAuth2AuthorizedClientService authorizedClientService;
private AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository; private AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository;
private AuthenticationFailureHandler failureHandler; private AuthenticationFailureHandler failureHandler;
private AuthenticationManager authenticationManager; private AuthenticationManager authenticationManager;
private OAuth2LoginAuthenticationToken loginAuthentication;
private OAuth2LoginAuthenticationFilter filter; private OAuth2LoginAuthenticationFilter filter;
@Before @Before
@ -107,11 +109,12 @@ public class OAuth2LoginAuthenticationFilterTests {
this.clientRegistrationRepository = new InMemoryClientRegistrationRepository( this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(
this.registration1, this.registration2); this.registration1, this.registration2);
this.authorizedClientService = new InMemoryOAuth2AuthorizedClientService(this.clientRegistrationRepository); this.authorizedClientService = new InMemoryOAuth2AuthorizedClientService(this.clientRegistrationRepository);
this.authorizedClientRepository = new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(this.authorizedClientService);
this.authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository(); this.authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository();
this.failureHandler = mock(AuthenticationFailureHandler.class); this.failureHandler = mock(AuthenticationFailureHandler.class);
this.authenticationManager = mock(AuthenticationManager.class); this.authenticationManager = mock(AuthenticationManager.class);
this.filter = spy(new OAuth2LoginAuthenticationFilter( this.filter = spy(new OAuth2LoginAuthenticationFilter(this.clientRegistrationRepository,
this.clientRegistrationRepository, this.authorizedClientService)); this.authorizedClientRepository, OAuth2LoginAuthenticationFilter.DEFAULT_FILTER_PROCESSES_URI));
this.filter.setAuthorizationRequestRepository(this.authorizationRequestRepository); this.filter.setAuthorizationRequestRepository(this.authorizationRequestRepository);
this.filter.setAuthenticationFailureHandler(this.failureHandler); this.filter.setAuthenticationFailureHandler(this.failureHandler);
this.filter.setAuthenticationManager(this.authenticationManager); this.filter.setAuthenticationManager(this.authenticationManager);
@ -129,9 +132,16 @@ public class OAuth2LoginAuthenticationFilterTests {
.isInstanceOf(IllegalArgumentException.class); .isInstanceOf(IllegalArgumentException.class);
} }
@Test
public void constructorWhenAuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2LoginAuthenticationFilter(this.clientRegistrationRepository,
(OAuth2AuthorizedClientRepository) null, OAuth2LoginAuthenticationFilter.DEFAULT_FILTER_PROCESSES_URI))
.isInstanceOf(IllegalArgumentException.class);
}
@Test @Test
public void constructorWhenFilterProcessesUrlIsNullThenThrowIllegalArgumentException() { public void constructorWhenFilterProcessesUrlIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2LoginAuthenticationFilter(this.clientRegistrationRepository, this.authorizedClientService, null)) assertThatThrownBy(() -> new OAuth2LoginAuthenticationFilter(this.clientRegistrationRepository, this.authorizedClientRepository, null))
.isInstanceOf(IllegalArgumentException.class); .isInstanceOf(IllegalArgumentException.class);
} }
@ -276,8 +286,8 @@ public class OAuth2LoginAuthenticationFilterTests {
this.filter.doFilter(request, response, filterChain); this.filter.doFilter(request, response, filterChain);
OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient( OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
this.registration1.getRegistrationId(), this.principalName1); this.registration1.getRegistrationId(), this.loginAuthentication, request);
assertThat(authorizedClient).isNotNull(); assertThat(authorizedClient).isNotNull();
assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration1); assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration1);
assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principalName1); assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principalName1);
@ -289,7 +299,7 @@ public class OAuth2LoginAuthenticationFilterTests {
public void doFilterWhenCustomFilterProcessesUrlThenFilterProcesses() throws Exception { public void doFilterWhenCustomFilterProcessesUrlThenFilterProcesses() throws Exception {
String filterProcessesUrl = "/login/oauth2/custom/*"; String filterProcessesUrl = "/login/oauth2/custom/*";
this.filter = spy(new OAuth2LoginAuthenticationFilter( this.filter = spy(new OAuth2LoginAuthenticationFilter(
this.clientRegistrationRepository, this.authorizedClientService, filterProcessesUrl)); this.clientRegistrationRepository, this.authorizedClientRepository, filterProcessesUrl));
this.filter.setAuthenticationManager(this.authenticationManager); this.filter.setAuthenticationManager(this.authenticationManager);
String requestUri = "/login/oauth2/custom/" + this.registration2.getRegistrationId(); String requestUri = "/login/oauth2/custom/" + this.registration2.getRegistrationId();
@ -324,13 +334,15 @@ public class OAuth2LoginAuthenticationFilterTests {
private void setUpAuthenticationResult(ClientRegistration registration) { private void setUpAuthenticationResult(ClientRegistration registration) {
OAuth2User user = mock(OAuth2User.class); OAuth2User user = mock(OAuth2User.class);
when(user.getName()).thenReturn(this.principalName1); when(user.getName()).thenReturn(this.principalName1);
OAuth2LoginAuthenticationToken loginAuthentication = mock(OAuth2LoginAuthenticationToken.class); this.loginAuthentication = mock(OAuth2LoginAuthenticationToken.class);
when(loginAuthentication.getPrincipal()).thenReturn(user); when(this.loginAuthentication.getPrincipal()).thenReturn(user);
when(loginAuthentication.getAuthorities()).thenReturn(AuthorityUtils.createAuthorityList("ROLE_USER")); when(this.loginAuthentication.getName()).thenReturn(this.principalName1);
when(loginAuthentication.getClientRegistration()).thenReturn(registration); when(this.loginAuthentication.getAuthorities()).thenReturn(AuthorityUtils.createAuthorityList("ROLE_USER"));
when(loginAuthentication.getAuthorizationExchange()).thenReturn(mock(OAuth2AuthorizationExchange.class)); when(this.loginAuthentication.getClientRegistration()).thenReturn(registration);
when(loginAuthentication.getAccessToken()).thenReturn(mock(OAuth2AccessToken.class)); when(this.loginAuthentication.getAuthorizationExchange()).thenReturn(mock(OAuth2AuthorizationExchange.class));
when(loginAuthentication.getRefreshToken()).thenReturn(mock(OAuth2RefreshToken.class)); when(this.loginAuthentication.getAccessToken()).thenReturn(mock(OAuth2AccessToken.class));
when(this.authenticationManager.authenticate(any(Authentication.class))).thenReturn(loginAuthentication); when(this.loginAuthentication.getRefreshToken()).thenReturn(mock(OAuth2RefreshToken.class));
when(this.loginAuthentication.isAuthenticated()).thenReturn(true);
when(this.authenticationManager.authenticate(any(Authentication.class))).thenReturn(this.loginAuthentication);
} }
} }

View File

@ -22,24 +22,29 @@ import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.autoconfigure.EnableAutoConfiguration; import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
import org.springframework.boot.test.autoconfigure.web.servlet.AutoConfigureMockMvc; import org.springframework.boot.test.autoconfigure.web.servlet.AutoConfigureMockMvc;
import org.springframework.boot.test.context.SpringBootTest; import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.ComponentScan; import org.springframework.context.annotation.ComponentScan;
import org.springframework.context.annotation.Import; import org.springframework.context.annotation.Import;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.mock.web.MockHttpSession; import org.springframework.mock.web.MockHttpSession;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository; import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository; import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationCodeGrantFilter; import org.springframework.security.oauth2.client.web.OAuth2AuthorizationCodeGrantFilter;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter; import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
@ -57,6 +62,7 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl;
@ -78,7 +84,7 @@ public class OAuth2AuthorizationCodeGrantApplicationTests {
private ClientRegistrationRepository clientRegistrationRepository; private ClientRegistrationRepository clientRegistrationRepository;
@Autowired @Autowired
private OAuth2AuthorizedClientService authorizedClientService; private OAuth2AuthorizedClientRepository authorizedClientRepository;
@Autowired @Autowired
private MockMvc mockMvc; private MockMvc mockMvc;
@ -116,18 +122,19 @@ public class OAuth2AuthorizationCodeGrantApplicationTests {
MockHttpSession session = (MockHttpSession) request.getSession(); MockHttpSession session = (MockHttpSession) request.getSession();
String principalName = "user"; String principalName = "user";
TestingAuthenticationToken authentication = new TestingAuthenticationToken(principalName, "password");
// Authorization Response // Authorization Response
this.mockMvc.perform(get("/github-repos") this.mockMvc.perform(get("/github-repos")
.param(OAuth2ParameterNames.CODE, "code") .param(OAuth2ParameterNames.CODE, "code")
.param(OAuth2ParameterNames.STATE, "state") .param(OAuth2ParameterNames.STATE, "state")
.with(user(principalName)) .with(authentication(authentication))
.session(session)) .session(session))
.andExpect(status().is3xxRedirection()) .andExpect(status().is3xxRedirection())
.andExpect(redirectedUrl("http://localhost/github-repos")); .andExpect(redirectedUrl("http://localhost/github-repos"));
OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient( OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
registration.getRegistrationId(), principalName); registration.getRegistrationId(), authentication, request);
assertThat(authorizedClient).isNotNull(); assertThat(authorizedClient).isNotNull();
} }
@ -164,5 +171,15 @@ public class OAuth2AuthorizationCodeGrantApplicationTests {
@ComponentScan(basePackages = "sample.web") @ComponentScan(basePackages = "sample.web")
@Import(WebClientConfig.class) @Import(WebClientConfig.class)
public static class SpringBootApplicationTestConfig { public static class SpringBootApplicationTestConfig {
@Bean
public OAuth2AuthorizedClientService authorizedClientService(ClientRegistrationRepository clientRegistrationRepository) {
return new InMemoryOAuth2AuthorizedClientService(clientRegistrationRepository);
}
@Bean
public OAuth2AuthorizedClientRepository authorizedClientRepository(OAuth2AuthorizedClientService authorizedClientService) {
return new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(authorizedClientService);
}
} }
} }

View File

@ -22,6 +22,9 @@ import org.springframework.security.config.annotation.web.configuration.WebSecur
import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.User;
import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.core.userdetails.UserDetailsService;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.provisioning.InMemoryUserDetailsManager; import org.springframework.security.provisioning.InMemoryUserDetailsManager;
/** /**
@ -43,6 +46,11 @@ public class SecurityConfig extends WebSecurityConfigurerAdapter {
.authorizationCodeGrant(); .authorizationCodeGrant();
} }
@Bean
public OAuth2AuthorizedClientRepository authorizedClientRepository(OAuth2AuthorizedClientService authorizedClientService) {
return new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(authorizedClientService);
}
@Bean @Bean
public UserDetailsService userDetailsService() { public UserDetailsService userDetailsService() {
UserDetails userDetails = User.withDefaultPasswordEncoder() UserDetails userDetails = User.withDefaultPasswordEncoder()

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2017 the original author or authors. * Copyright 2002-2018 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -45,7 +45,9 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest; import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
import org.springframework.security.oauth2.client.userinfo.OAuth2UserService; import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter; import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.OAuth2LoginAuthenticationFilter; import org.springframework.security.oauth2.client.web.OAuth2LoginAuthenticationFilter;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
@ -403,12 +405,14 @@ public class OAuth2LoginApplicationTests {
@ComponentScan(basePackages = "sample.web") @ComponentScan(basePackages = "sample.web")
public static class SpringBootApplicationTestConfig { public static class SpringBootApplicationTestConfig {
@Autowired @Bean
private ClientRegistrationRepository clientRegistrationRepository; public OAuth2AuthorizedClientService authorizedClientService(ClientRegistrationRepository clientRegistrationRepository) {
return new InMemoryOAuth2AuthorizedClientService(clientRegistrationRepository);
}
@Bean @Bean
public OAuth2AuthorizedClientService authorizedClientService() { public OAuth2AuthorizedClientRepository authorizedClientRepository(OAuth2AuthorizedClientService authorizedClientService) {
return new InMemoryOAuth2AuthorizedClientService(this.clientRegistrationRepository); return new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(authorizedClientService);
} }
} }
} }

View File

@ -17,6 +17,10 @@ package sample;
import org.springframework.boot.SpringApplication; import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
/** /**
* @author Joe Grandja * @author Joe Grandja
@ -27,4 +31,9 @@ public class OAuth2LoginApplication {
public static void main(String[] args) { public static void main(String[] args) {
SpringApplication.run(OAuth2LoginApplication.class, args); SpringApplication.run(OAuth2LoginApplication.class, args);
} }
@Bean
public OAuth2AuthorizedClientRepository authorizedClientRepository(OAuth2AuthorizedClientService authorizedClientService) {
return new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(authorizedClientService);
}
} }