From 951e64185bbbe6249c699cff0f9a361ed0cf9315 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Wed, 24 Jun 2020 08:36:09 -0400 Subject: [PATCH] Register OAuth2AuthorizedClientArgumentResolver for XML Config Closes gh-8669 --- .../http/AuthenticationConfigBuilder.java | 65 ++++++++++--- .../OAuth2ClientBeanDefinitionParser.java | 62 +++++-------- ...OAuth2ClientBeanDefinitionParserUtils.java | 79 ++++++++++++++++ ...uth2ClientWebMvcSecurityPostProcessor.java | 91 +++++++++++++++++++ .../http/OAuth2LoginBeanDefinitionParser.java | 76 ++++++---------- ...OAuth2ClientBeanDefinitionParserTests.java | 31 +++++++ .../OAuth2LoginBeanDefinitionParserTests.java | 46 +++++++++- ...Tests-AuthorizedClientArgumentResolver.xml | 52 +++++++++++ ...Tests-AuthorizedClientArgumentResolver.xml | 45 +++++++++ 9 files changed, 443 insertions(+), 104 deletions(-) create mode 100644 config/src/main/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserUtils.java create mode 100644 config/src/main/java/org/springframework/security/config/http/OAuth2ClientWebMvcSecurityPostProcessor.java create mode 100644 config/src/test/resources/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests-AuthorizedClientArgumentResolver.xml create mode 100644 config/src/test/resources/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests-AuthorizedClientArgumentResolver.xml diff --git a/config/src/main/java/org/springframework/security/config/http/AuthenticationConfigBuilder.java b/config/src/main/java/org/springframework/security/config/http/AuthenticationConfigBuilder.java index 6de6a2f711..3e82001c40 100644 --- a/config/src/main/java/org/springframework/security/config/http/AuthenticationConfigBuilder.java +++ b/config/src/main/java/org/springframework/security/config/http/AuthenticationConfigBuilder.java @@ -15,18 +15,8 @@ */ package org.springframework.security.config.http; -import java.security.SecureRandom; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.function.Function; -import javax.servlet.http.HttpServletRequest; - import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.w3c.dom.Element; - import org.springframework.beans.BeanMetadataElement; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanReference; @@ -63,8 +53,18 @@ import org.springframework.security.web.authentication.www.BasicAuthenticationEn import org.springframework.security.web.authentication.www.BasicAuthenticationFilter; import org.springframework.security.web.csrf.CsrfToken; import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; import org.springframework.util.StringUtils; import org.springframework.util.xml.DomUtils; +import org.w3c.dom.Element; + +import javax.servlet.http.HttpServletRequest; +import java.security.SecureRandom; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.Function; import static org.springframework.security.config.http.SecurityFilters.ANONYMOUS_FILTER; import static org.springframework.security.config.http.SecurityFilters.BASIC_AUTH_FILTER; @@ -160,12 +160,16 @@ final class AuthenticationConfigBuilder { private String openIDLoginPage; + private boolean oauth2LoginEnabled; + private boolean defaultAuthorizedClientRepositoryRegistered; private String oauth2LoginFilterId; private BeanDefinition oauth2AuthorizationRequestRedirectFilter; private BeanDefinition oauth2LoginEntryPoint; private BeanReference oauth2LoginAuthenticationProviderRef; private BeanReference oauth2LoginOidcAuthenticationProviderRef; private BeanDefinition oauth2LoginLinks; + + private boolean oauth2ClientEnabled; private BeanDefinition authorizationRequestRedirectFilter; private BeanDefinition authorizationCodeGrantFilter; private BeanReference authorizationCodeAuthenticationProviderRef; @@ -196,8 +200,7 @@ final class AuthenticationConfigBuilder { createBasicFilter(authenticationManager); createBearerTokenAuthenticationFilter(authenticationManager); createFormLoginFilter(sessionStrategy, authenticationManager); - createOAuth2LoginFilter(sessionStrategy, authenticationManager); - createOAuth2ClientFilter(requestCache, authenticationManager); + createOAuth2ClientFilters(sessionStrategy, requestCache, authenticationManager); createOpenIDLoginFilter(sessionStrategy, authenticationManager); createX509Filter(authenticationManager); createJeeFilter(authenticationManager); @@ -274,15 +277,27 @@ final class AuthenticationConfigBuilder { } } + void createOAuth2ClientFilters(BeanReference sessionStrategy, BeanReference requestCache, + BeanReference authenticationManager) { + createOAuth2LoginFilter(sessionStrategy, authenticationManager); + createOAuth2ClientFilter(requestCache, authenticationManager); + registerOAuth2ClientPostProcessors(); + } + void createOAuth2LoginFilter(BeanReference sessionStrategy, BeanReference authManager) { Element oauth2LoginElt = DomUtils.getChildElementByTagName(this.httpElt, Elements.OAUTH2_LOGIN); if (oauth2LoginElt == null) { return; } + this.oauth2LoginEnabled = true; OAuth2LoginBeanDefinitionParser parser = new OAuth2LoginBeanDefinitionParser(requestCache, portMapper, portResolver, sessionStrategy, allowSessionCreation); BeanDefinition oauth2LoginFilterBean = parser.parse(oauth2LoginElt, this.pc); + + BeanDefinition defaultAuthorizedClientRepository = parser.getDefaultAuthorizedClientRepository(); + registerDefaultAuthorizedClientRepositoryIfNecessary(defaultAuthorizedClientRepository); + oauth2LoginFilterBean.getPropertyValues().addPropertyValue("authenticationManager", authManager); // retrieve the other bean result @@ -319,11 +334,15 @@ final class AuthenticationConfigBuilder { if (oauth2ClientElt == null) { return; } + this.oauth2ClientEnabled = true; OAuth2ClientBeanDefinitionParser parser = new OAuth2ClientBeanDefinitionParser( requestCache, authenticationManager); parser.parse(oauth2ClientElt, this.pc); + BeanDefinition defaultAuthorizedClientRepository = parser.getDefaultAuthorizedClientRepository(); + registerDefaultAuthorizedClientRepositoryIfNecessary(defaultAuthorizedClientRepository); + this.authorizationRequestRedirectFilter = parser.getAuthorizationRequestRedirectFilter(); String authorizationRequestRedirectFilterId = pc.getReaderContext() .generateBeanName(this.authorizationRequestRedirectFilter); @@ -344,6 +363,28 @@ final class AuthenticationConfigBuilder { this.authorizationCodeAuthenticationProviderRef = new RuntimeBeanReference(authorizationCodeAuthenticationProviderId); } + void registerDefaultAuthorizedClientRepositoryIfNecessary(BeanDefinition defaultAuthorizedClientRepository) { + if (!this.defaultAuthorizedClientRepositoryRegistered && defaultAuthorizedClientRepository != null) { + String authorizedClientRepositoryId = pc.getReaderContext() + .generateBeanName(defaultAuthorizedClientRepository); + this.pc.registerBeanComponent(new BeanComponentDefinition( + defaultAuthorizedClientRepository, authorizedClientRepositoryId)); + this.defaultAuthorizedClientRepositoryRegistered = true; + } + } + + private void registerOAuth2ClientPostProcessors() { + if (!this.oauth2LoginEnabled && !this.oauth2ClientEnabled) { + return; + } + + boolean webmvcPresent = ClassUtils.isPresent("org.springframework.web.servlet.DispatcherServlet", getClass().getClassLoader()); + if (webmvcPresent) { + this.pc.getReaderContext().registerWithGeneratedName( + new RootBeanDefinition(OAuth2ClientWebMvcSecurityPostProcessor.class)); + } + } + void createOpenIDLoginFilter(BeanReference sessionStrategy, BeanReference authManager) { Element openIDLoginElt = DomUtils.getChildElementByTagName(httpElt, Elements.OPENID_LOGIN); diff --git a/config/src/main/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParser.java index 269143ede0..71cd14661c 100644 --- a/config/src/main/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParser.java @@ -23,27 +23,30 @@ import org.springframework.beans.factory.support.BeanDefinitionBuilder; import org.springframework.beans.factory.xml.BeanDefinitionParser; import org.springframework.beans.factory.xml.ParserContext; import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationProvider; -import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.web.OAuth2AuthorizationCodeGrantFilter; import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter; import org.springframework.util.StringUtils; import org.springframework.util.xml.DomUtils; import org.w3c.dom.Element; +import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.createAuthorizedClientRepository; +import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.createDefaultAuthorizedClientRepository; +import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.getAuthorizedClientRepository; +import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.getAuthorizedClientService; +import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.getClientRegistrationRepository; + /** * @author Joe Grandja * @since 5.3 */ final class OAuth2ClientBeanDefinitionParser implements BeanDefinitionParser { private static final String ELT_AUTHORIZATION_CODE_GRANT = "authorization-code-grant"; - private static final String ATT_CLIENT_REGISTRATION_REPOSITORY_REF = "client-registration-repository-ref"; - private static final String ATT_AUTHORIZED_CLIENT_REPOSITORY_REF = "authorized-client-repository-ref"; - private static final String ATT_AUTHORIZED_CLIENT_SERVICE_REF = "authorized-client-service-ref"; private static final String ATT_AUTHORIZATION_REQUEST_REPOSITORY_REF = "authorization-request-repository-ref"; private static final String ATT_AUTHORIZATION_REQUEST_RESOLVER_REF = "authorization-request-resolver-ref"; private static final String ATT_ACCESS_TOKEN_RESPONSE_CLIENT_REF = "access-token-response-client-ref"; private final BeanReference requestCache; private final BeanReference authenticationManager; + private BeanDefinition defaultAuthorizedClientRepository; private BeanDefinition authorizationRequestRedirectFilter; private BeanDefinition authorizationCodeGrantFilter; private BeanDefinition authorizationCodeAuthenticationProvider; @@ -58,8 +61,16 @@ final class OAuth2ClientBeanDefinitionParser implements BeanDefinitionParser { Element authorizationCodeGrantElt = DomUtils.getChildElementByTagName(element, ELT_AUTHORIZATION_CODE_GRANT); BeanMetadataElement clientRegistrationRepository = getClientRegistrationRepository(element); - BeanMetadataElement authorizedClientRepository = getAuthorizedClientRepository( - element, clientRegistrationRepository); + BeanMetadataElement authorizedClientRepository = getAuthorizedClientRepository(element); + if (authorizedClientRepository == null) { + BeanMetadataElement authorizedClientService = getAuthorizedClientService(element); + if (authorizedClientService == null) { + this.defaultAuthorizedClientRepository = createDefaultAuthorizedClientRepository(clientRegistrationRepository); + authorizedClientRepository = this.defaultAuthorizedClientRepository; + } else { + authorizedClientRepository = createAuthorizedClientRepository(authorizedClientService); + } + } BeanMetadataElement authorizationRequestRepository = getAuthorizationRequestRepository( authorizationCodeGrantElt); @@ -95,41 +106,6 @@ final class OAuth2ClientBeanDefinitionParser implements BeanDefinitionParser { return null; } - private BeanMetadataElement getClientRegistrationRepository(Element element) { - BeanMetadataElement clientRegistrationRepository; - String clientRegistrationRepositoryRef = element.getAttribute(ATT_CLIENT_REGISTRATION_REPOSITORY_REF); - if (!StringUtils.isEmpty(clientRegistrationRepositoryRef)) { - clientRegistrationRepository = new RuntimeBeanReference(clientRegistrationRepositoryRef); - } else { - clientRegistrationRepository = new RuntimeBeanReference(ClientRegistrationRepository.class); - } - return clientRegistrationRepository; - } - - private BeanMetadataElement getAuthorizedClientRepository(Element element, - BeanMetadataElement clientRegistrationRepository) { - BeanMetadataElement authorizedClientRepository; - String authorizedClientRepositoryRef = element.getAttribute(ATT_AUTHORIZED_CLIENT_REPOSITORY_REF); - if (!StringUtils.isEmpty(authorizedClientRepositoryRef)) { - authorizedClientRepository = new RuntimeBeanReference(authorizedClientRepositoryRef); - } else { - BeanMetadataElement authorizedClientService; - String authorizedClientServiceRef = element.getAttribute(ATT_AUTHORIZED_CLIENT_SERVICE_REF); - if (!StringUtils.isEmpty(authorizedClientServiceRef)) { - authorizedClientService = new RuntimeBeanReference(authorizedClientServiceRef); - } else { - authorizedClientService = BeanDefinitionBuilder - .rootBeanDefinition( - "org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService") - .addConstructorArgValue(clientRegistrationRepository).getBeanDefinition(); - } - authorizedClientRepository = BeanDefinitionBuilder.rootBeanDefinition( - "org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository") - .addConstructorArgValue(authorizedClientService).getBeanDefinition(); - } - return authorizedClientRepository; - } - private BeanMetadataElement getAuthorizationRequestRepository(Element element) { BeanMetadataElement authorizationRequestRepository; String authorizationRequestRepositoryRef = element != null ? @@ -158,6 +134,10 @@ final class OAuth2ClientBeanDefinitionParser implements BeanDefinitionParser { return accessTokenResponseClient; } + BeanDefinition getDefaultAuthorizedClientRepository() { + return this.defaultAuthorizedClientRepository; + } + BeanDefinition getAuthorizationRequestRedirectFilter() { return this.authorizationRequestRedirectFilter; } diff --git a/config/src/main/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserUtils.java b/config/src/main/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserUtils.java new file mode 100644 index 0000000000..4ff56b1147 --- /dev/null +++ b/config/src/main/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserUtils.java @@ -0,0 +1,79 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.config.http; + +import org.springframework.beans.BeanMetadataElement; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.config.RuntimeBeanReference; +import org.springframework.beans.factory.support.BeanDefinitionBuilder; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.util.StringUtils; +import org.w3c.dom.Element; + +/** + * @author Joe Grandja + * @since 5.4 + */ +final class OAuth2ClientBeanDefinitionParserUtils { + private static final String ATT_CLIENT_REGISTRATION_REPOSITORY_REF = "client-registration-repository-ref"; + private static final String ATT_AUTHORIZED_CLIENT_REPOSITORY_REF = "authorized-client-repository-ref"; + private static final String ATT_AUTHORIZED_CLIENT_SERVICE_REF = "authorized-client-service-ref"; + + static BeanMetadataElement getClientRegistrationRepository(Element element) { + BeanMetadataElement clientRegistrationRepository; + String clientRegistrationRepositoryRef = element.getAttribute(ATT_CLIENT_REGISTRATION_REPOSITORY_REF); + if (!StringUtils.isEmpty(clientRegistrationRepositoryRef)) { + clientRegistrationRepository = new RuntimeBeanReference(clientRegistrationRepositoryRef); + } else { + clientRegistrationRepository = new RuntimeBeanReference(ClientRegistrationRepository.class); + } + return clientRegistrationRepository; + } + + static BeanMetadataElement getAuthorizedClientRepository(Element element) { + String authorizedClientRepositoryRef = element.getAttribute(ATT_AUTHORIZED_CLIENT_REPOSITORY_REF); + if (!StringUtils.isEmpty(authorizedClientRepositoryRef)) { + return new RuntimeBeanReference(authorizedClientRepositoryRef); + } + return null; + } + + static BeanMetadataElement getAuthorizedClientService(Element element) { + String authorizedClientServiceRef = element.getAttribute(ATT_AUTHORIZED_CLIENT_SERVICE_REF); + if (!StringUtils.isEmpty(authorizedClientServiceRef)) { + return new RuntimeBeanReference(authorizedClientServiceRef); + } + return null; + } + + static BeanMetadataElement createAuthorizedClientRepository(BeanMetadataElement authorizedClientService) { + return BeanDefinitionBuilder.rootBeanDefinition( + "org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository") + .addConstructorArgValue(authorizedClientService) + .getBeanDefinition(); + } + + static BeanDefinition createDefaultAuthorizedClientRepository(BeanMetadataElement clientRegistrationRepository) { + BeanDefinition authorizedClientService = BeanDefinitionBuilder.rootBeanDefinition( + "org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService") + .addConstructorArgValue(clientRegistrationRepository) + .getBeanDefinition(); + return BeanDefinitionBuilder.rootBeanDefinition( + "org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository") + .addConstructorArgValue(authorizedClientService) + .getBeanDefinition(); + } +} diff --git a/config/src/main/java/org/springframework/security/config/http/OAuth2ClientWebMvcSecurityPostProcessor.java b/config/src/main/java/org/springframework/security/config/http/OAuth2ClientWebMvcSecurityPostProcessor.java new file mode 100644 index 0000000000..c6d404179d --- /dev/null +++ b/config/src/main/java/org/springframework/security/config/http/OAuth2ClientWebMvcSecurityPostProcessor.java @@ -0,0 +1,91 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.config.http; + +import org.springframework.beans.BeansException; +import org.springframework.beans.PropertyValue; +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.BeanFactoryAware; +import org.springframework.beans.factory.BeanFactoryUtils; +import org.springframework.beans.factory.ListableBeanFactory; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.beans.factory.support.BeanDefinitionBuilder; +import org.springframework.beans.factory.support.BeanDefinitionRegistry; +import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor; +import org.springframework.beans.factory.support.ManagedList; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.client.web.method.annotation.OAuth2AuthorizedClientArgumentResolver; +import org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter; + +/** + * @author Joe Grandja + * @since 5.4 + */ +final class OAuth2ClientWebMvcSecurityPostProcessor implements BeanDefinitionRegistryPostProcessor, BeanFactoryAware { + private static final String ARGUMENT_RESOLVERS_PROPERTY = "argumentResolvers"; + private BeanFactory beanFactory; + + @Override + public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException { + String[] clientRegistrationRepositoryBeanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors( + (ListableBeanFactory) this.beanFactory, ClientRegistrationRepository.class, false, false); + String[] authorizedClientRepositoryBeanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors( + (ListableBeanFactory) this.beanFactory, OAuth2AuthorizedClientRepository.class, false, false); + + if (clientRegistrationRepositoryBeanNames.length != 1 || authorizedClientRepositoryBeanNames.length != 1) { + return; + } + + for (String beanName : registry.getBeanDefinitionNames()) { + BeanDefinition beanDefinition = registry.getBeanDefinition(beanName); + if (RequestMappingHandlerAdapter.class.getName().equals(beanDefinition.getBeanClassName())) { + PropertyValue currentArgumentResolvers = + beanDefinition.getPropertyValues().getPropertyValue(ARGUMENT_RESOLVERS_PROPERTY); + ManagedList argumentResolvers = new ManagedList<>(); + if (currentArgumentResolvers != null) { + argumentResolvers.addAll((ManagedList) currentArgumentResolvers.getValue()); + } + + String[] authorizedClientManagerBeanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors( + (ListableBeanFactory) this.beanFactory, OAuth2AuthorizedClientManager.class, false, false); + + BeanDefinitionBuilder beanDefinitionBuilder = + BeanDefinitionBuilder.genericBeanDefinition(OAuth2AuthorizedClientArgumentResolver.class); + if (authorizedClientManagerBeanNames.length == 1) { + beanDefinitionBuilder.addConstructorArgReference(authorizedClientManagerBeanNames[0]); + } else { + beanDefinitionBuilder.addConstructorArgReference(clientRegistrationRepositoryBeanNames[0]); + beanDefinitionBuilder.addConstructorArgReference(authorizedClientRepositoryBeanNames[0]); + } + argumentResolvers.add(beanDefinitionBuilder.getBeanDefinition()); + beanDefinition.getPropertyValues().add(ARGUMENT_RESOLVERS_PROPERTY, argumentResolvers); + break; + } + } + } + + @Override + public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException { + } + + @Override + public void setBeanFactory(BeanFactory beanFactory) throws BeansException { + this.beanFactory = beanFactory; + } +} diff --git a/config/src/main/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParser.java index 0d2c7c4410..b5fcdda680 100644 --- a/config/src/main/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParser.java @@ -15,13 +15,6 @@ */ package org.springframework.security.config.http; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; - import org.springframework.beans.BeanMetadataElement; import org.springframework.beans.BeansException; import org.springframework.beans.factory.config.BeanDefinition; @@ -66,6 +59,19 @@ import org.springframework.web.accept.ContentNegotiationStrategy; import org.springframework.web.accept.HeaderContentNegotiationStrategy; import org.w3c.dom.Element; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.createAuthorizedClientRepository; +import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.createDefaultAuthorizedClientRepository; +import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.getAuthorizedClientRepository; +import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.getAuthorizedClientService; +import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.getClientRegistrationRepository; + /** * @author Ruby Hartono * @since 5.3 @@ -77,9 +83,6 @@ final class OAuth2LoginBeanDefinitionParser implements BeanDefinitionParser { private static final String ELT_CLIENT_REGISTRATION = "client-registration"; private static final String ATT_REGISTRATION_ID = "registration-id"; - private static final String ATT_CLIENT_REGISTRATION_REPOSITORY_REF = "client-registration-repository-ref"; - private static final String ATT_AUTHORIZED_CLIENT_REPOSITORY_REF = "authorized-client-repository-ref"; - private static final String ATT_AUTHORIZED_CLIENT_SERVICE_REF = "authorized-client-service-ref"; private static final String ATT_AUTHORIZATION_REQUEST_REPOSITORY_REF = "authorization-request-repository-ref"; private static final String ATT_AUTHORIZATION_REQUEST_RESOLVER_REF = "authorization-request-resolver-ref"; private static final String ATT_ACCESS_TOKEN_RESPONSE_CLIENT_REF = "access-token-response-client-ref"; @@ -98,6 +101,8 @@ final class OAuth2LoginBeanDefinitionParser implements BeanDefinitionParser { private final BeanReference sessionStrategy; private final boolean allowSessionCreation; + private BeanDefinition defaultAuthorizedClientRepository; + private BeanDefinition oauth2AuthorizationRequestRedirectFilter; private BeanDefinition oauth2LoginAuthenticationEntryPoint; @@ -128,8 +133,16 @@ final class OAuth2LoginBeanDefinitionParser implements BeanDefinitionParser { // configure filter BeanMetadataElement clientRegistrationRepository = getClientRegistrationRepository(element); - BeanMetadataElement authorizedClientRepository = getAuthorizedClientRepository(element, - clientRegistrationRepository); + BeanMetadataElement authorizedClientRepository = getAuthorizedClientRepository(element); + if (authorizedClientRepository == null) { + BeanMetadataElement authorizedClientService = getAuthorizedClientService(element); + if (authorizedClientService == null) { + this.defaultAuthorizedClientRepository = createDefaultAuthorizedClientRepository(clientRegistrationRepository); + authorizedClientRepository = this.defaultAuthorizedClientRepository; + } else { + authorizedClientRepository = createAuthorizedClientRepository(authorizedClientService); + } + } BeanMetadataElement accessTokenResponseClient = getAccessTokenResponseClient(element); BeanMetadataElement oauth2UserService = getOAuth2UserService(element); BeanMetadataElement authorizationRequestRepository = getAuthorizationRequestRepository(element); @@ -251,41 +264,6 @@ final class OAuth2LoginBeanDefinitionParser implements BeanDefinitionParser { return authorizationRequestRepository; } - private BeanMetadataElement getAuthorizedClientRepository(Element element, - BeanMetadataElement clientRegistrationRepository) { - BeanMetadataElement authorizedClientRepository; - String authorizedClientRepositoryRef = element.getAttribute(ATT_AUTHORIZED_CLIENT_REPOSITORY_REF); - if (!StringUtils.isEmpty(authorizedClientRepositoryRef)) { - authorizedClientRepository = new RuntimeBeanReference(authorizedClientRepositoryRef); - } else { - BeanMetadataElement authorizedClientService; - String authorizedClientServiceRef = element.getAttribute(ATT_AUTHORIZED_CLIENT_SERVICE_REF); - if (!StringUtils.isEmpty(authorizedClientServiceRef)) { - authorizedClientService = new RuntimeBeanReference(authorizedClientServiceRef); - } else { - authorizedClientService = BeanDefinitionBuilder - .rootBeanDefinition( - "org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService") - .addConstructorArgValue(clientRegistrationRepository).getBeanDefinition(); - } - authorizedClientRepository = BeanDefinitionBuilder.rootBeanDefinition( - "org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository") - .addConstructorArgValue(authorizedClientService).getBeanDefinition(); - } - return authorizedClientRepository; - } - - private BeanMetadataElement getClientRegistrationRepository(Element element) { - BeanMetadataElement clientRegistrationRepository; - String clientRegistrationRepositoryRef = element.getAttribute(ATT_CLIENT_REGISTRATION_REPOSITORY_REF); - if (!StringUtils.isEmpty(clientRegistrationRepositoryRef)) { - clientRegistrationRepository = new RuntimeBeanReference(clientRegistrationRepositoryRef); - } else { - clientRegistrationRepository = new RuntimeBeanReference(ClientRegistrationRepository.class); - } - return clientRegistrationRepository; - } - private BeanDefinition getOidcAuthProvider(Element element, BeanMetadataElement accessTokenResponseClient, String userAuthoritiesMapperRef) { @@ -353,6 +331,10 @@ final class OAuth2LoginBeanDefinitionParser implements BeanDefinitionParser { return accessTokenResponseClient; } + BeanDefinition getDefaultAuthorizedClientRepository() { + return this.defaultAuthorizedClientRepository; + } + BeanDefinition getOAuth2AuthorizationRequestRedirectFilter() { return oauth2AuthorizationRequestRedirectFilter; } diff --git a/config/src/test/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests.java index 6774cfcdf2..c466ed8e35 100644 --- a/config/src/test/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests.java @@ -24,6 +24,7 @@ import org.springframework.security.config.oauth2.client.CommonOAuth2Provider; import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService; +import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistration; @@ -31,6 +32,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository; import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; @@ -41,6 +43,8 @@ import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RestController; import java.util.HashMap; import java.util.Map; @@ -51,6 +55,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses.accessTokenResponse; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; @@ -200,6 +205,32 @@ public class OAuth2ClientBeanDefinitionParserTests { verify(this.authorizedClientService).saveAuthorizedClient(any(), any()); } + @WithMockUser + @Test + public void requestWhenAuthorizedClientFoundThenMethodArgumentResolved() throws Exception { + this.spring.configLocations(xml("AuthorizedClientArgumentResolver")).autowire(); + + ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("google"); + + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + clientRegistration, "user", TestOAuth2AccessTokens.noScopes()); + when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())) + .thenReturn(authorizedClient); + + this.mvc.perform(get("/authorized-client")) + .andExpect(status().isOk()) + .andExpect(content().string("resolved")); + } + + @RestController + static class AuthorizedClientController { + + @GetMapping("/authorized-client") + String authorizedClient(@RegisteredOAuth2AuthorizedClient("google") OAuth2AuthorizedClient authorizedClient) { + return authorizedClient != null ? "resolved" : "not-resolved"; + } + } + private static OAuth2AuthorizationRequest createAuthorizationRequest(ClientRegistration clientRegistration) { Map attributes = new HashMap<>(); attributes.put(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId()); diff --git a/config/src/test/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests.java index 9470cd03b4..0fa4938b5c 100644 --- a/config/src/test/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests.java @@ -17,6 +17,7 @@ package org.springframework.security.config.http; import org.junit.Rule; import org.junit.Test; +import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationListener; @@ -28,7 +29,9 @@ import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService; +import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistration; @@ -40,6 +43,7 @@ import org.springframework.security.oauth2.client.web.AuthorizationRequestReposi import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; @@ -50,13 +54,22 @@ import org.springframework.security.oauth2.core.user.TestOAuth2Users; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.JwtDecoderFactory; import org.springframework.security.oauth2.jwt.TestJwts; +import org.springframework.security.test.context.annotation.SecurityTestExecutionListeners; +import org.springframework.security.test.context.support.WithMockUser; import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.AuthenticationSuccessHandler; import org.springframework.security.web.savedrequest.RequestCache; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RestController; + +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; @@ -66,18 +79,17 @@ import static org.mockito.Mockito.when; import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses.accessTokenResponse; import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses.oidcAccessTokenResponse; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; -import java.util.Collection; -import java.util.HashMap; -import java.util.Map; - /** * Tests for {@link OAuth2LoginBeanDefinitionParser}. * * @author Ruby Hartono */ +@RunWith(SpringJUnit4ClassRunner.class) +@SecurityTestExecutionListeners public class OAuth2LoginBeanDefinitionParserTests { private static final String CONFIG_LOCATION_PREFIX = "classpath:org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests"; @@ -489,6 +501,32 @@ public class OAuth2LoginBeanDefinitionParserTests { verify(authorizedClientService).saveAuthorizedClient(any(), any()); } + @WithMockUser + @Test + public void requestWhenAuthorizedClientFoundThenMethodArgumentResolved() throws Exception { + this.spring.configLocations(xml("AuthorizedClientArgumentResolver")).autowire(); + + ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("google-login"); + + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + clientRegistration, "user", TestOAuth2AccessTokens.noScopes()); + when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())) + .thenReturn(authorizedClient); + + this.mvc.perform(get("/authorized-client")) + .andExpect(status().isOk()) + .andExpect(content().string("resolved")); + } + + @RestController + static class AuthorizedClientController { + + @GetMapping("/authorized-client") + String authorizedClient(@RegisteredOAuth2AuthorizedClient("google") OAuth2AuthorizedClient authorizedClient) { + return authorizedClient != null ? "resolved" : "not-resolved"; + } + } + private String xml(String configName) { return CONFIG_LOCATION_PREFIX + "-" + configName + ".xml"; } diff --git a/config/src/test/resources/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests-AuthorizedClientArgumentResolver.xml b/config/src/test/resources/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests-AuthorizedClientArgumentResolver.xml new file mode 100644 index 0000000000..bf59662367 --- /dev/null +++ b/config/src/test/resources/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests-AuthorizedClientArgumentResolver.xml @@ -0,0 +1,52 @@ + + + + + + + + + + + + + + + + + + + + + + + diff --git a/config/src/test/resources/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests-AuthorizedClientArgumentResolver.xml b/config/src/test/resources/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests-AuthorizedClientArgumentResolver.xml new file mode 100644 index 0000000000..dd0afc9351 --- /dev/null +++ b/config/src/test/resources/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests-AuthorizedClientArgumentResolver.xml @@ -0,0 +1,45 @@ + + + + + + + + + + + + + + + + + + + + +