From 197737a2b4539b8dc4d31f8ae016af77d4762f60 Mon Sep 17 00:00:00 2001 From: Luke Taylor Date: Wed, 4 Nov 2009 14:55:58 +0000 Subject: [PATCH] SEC-1281: make sure correct 'key' value is used for RememberMeAuthenticationProvider when external RememberMeServices is used --- .../http/AuthenticationConfigBuilder.java | 28 +++++++++---------- .../http/RememberMeBeanDefinitionParser.java | 15 ++++------ ...HttpSecurityBeanDefinitionParserTests.java | 23 +++++++++++++-- 3 files changed, 39 insertions(+), 27 deletions(-) diff --git a/config/src/main/java/org/springframework/security/config/http/AuthenticationConfigBuilder.java b/config/src/main/java/org/springframework/security/config/http/AuthenticationConfigBuilder.java index efab8e7b31..627f02fec8 100644 --- a/config/src/main/java/org/springframework/security/config/http/AuthenticationConfigBuilder.java +++ b/config/src/main/java/org/springframework/security/config/http/AuthenticationConfigBuilder.java @@ -118,33 +118,31 @@ final class AuthenticationConfigBuilder { } void createRememberMeFilter(BeanReference authenticationManager) { + final String ATT_KEY = "key"; + final String DEF_KEY = "SpringSecured"; + // Parse remember me before logout as RememberMeServices is also a LogoutHandler implementation. Element rememberMeElt = DomUtils.getChildElementByTagName(httpElt, Elements.REMEMBER_ME); if (rememberMeElt != null) { - rememberMeFilter = (RootBeanDefinition) new RememberMeBeanDefinitionParser().parse(rememberMeElt, pc); + String key = rememberMeElt.getAttribute(ATT_KEY); + + if (!StringUtils.hasText(key)) { + key = DEF_KEY; + } + + rememberMeFilter = (RootBeanDefinition) new RememberMeBeanDefinitionParser(key).parse(rememberMeElt, pc); rememberMeFilter.getPropertyValues().addPropertyValue("authenticationManager", authenticationManager); rememberMeServicesId = ((RuntimeBeanReference) rememberMeFilter.getPropertyValues().getPropertyValue("rememberMeServices").getValue()).getBeanName(); - createRememberMeProvider(); + createRememberMeProvider(key); } } - private void createRememberMeProvider() { + private void createRememberMeProvider(String key) { RootBeanDefinition provider = new RootBeanDefinition(RememberMeAuthenticationProvider.class); provider.setSource(rememberMeFilter.getSource()); - // Locate the RememberMeServices bean and read the "key" property from it - PropertyValue key = null; - if (pc.getRegistry().containsBeanDefinition(rememberMeServicesId)) { - BeanDefinition services = pc.getRegistry().getBeanDefinition(rememberMeServicesId); - key = services.getPropertyValues().getPropertyValue("key"); - } - - if (key == null) { - key = new PropertyValue("key", RememberMeBeanDefinitionParser.DEF_KEY); - } - - provider.getPropertyValues().addPropertyValue(key); + provider.getPropertyValues().addPropertyValue("key", key); String id = pc.getReaderContext().registerWithGeneratedName(provider); pc.registerBeanComponent(new BeanComponentDefinition(provider, id)); diff --git a/config/src/main/java/org/springframework/security/config/http/RememberMeBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/http/RememberMeBeanDefinitionParser.java index f09fef5160..8dae5ebfa3 100644 --- a/config/src/main/java/org/springframework/security/config/http/RememberMeBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/http/RememberMeBeanDefinitionParser.java @@ -23,9 +23,6 @@ import org.w3c.dom.Element; * @version $Id$ */ class RememberMeBeanDefinitionParser implements BeanDefinitionParser { - static final String ATT_KEY = "key"; - static final String DEF_KEY = "SpringSecured"; - static final String ATT_DATA_SOURCE = "data-source-ref"; static final String ATT_SERVICES_REF = "services-ref"; static final String ATT_SERVICES_ALIAS = "services-alias"; @@ -36,6 +33,11 @@ class RememberMeBeanDefinitionParser implements BeanDefinitionParser { protected final Log logger = LogFactory.getLog(getClass()); private String servicesName; + private final String key; + + RememberMeBeanDefinitionParser(String key) { + this.key = key; + } public BeanDefinition parse(Element element, ParserContext pc) { CompositeComponentDefinition compositeDef = @@ -44,16 +46,11 @@ class RememberMeBeanDefinitionParser implements BeanDefinitionParser { String tokenRepository = element.getAttribute(ATT_TOKEN_REPOSITORY); String dataSource = element.getAttribute(ATT_DATA_SOURCE); - String key = element.getAttribute(ATT_KEY); String userServiceRef = element.getAttribute(ATT_USER_SERVICE_REF); String rememberMeServicesRef = element.getAttribute(ATT_SERVICES_REF); String tokenValiditySeconds = element.getAttribute(ATT_TOKEN_VALIDITY); Object source = pc.extractSource(element); - if (!StringUtils.hasText(key)) { - key = DEF_KEY; - } - RootBeanDefinition services = null; boolean dataSourceSet = StringUtils.hasText(dataSource); @@ -108,7 +105,7 @@ class RememberMeBeanDefinitionParser implements BeanDefinitionParser { services.getPropertyValues().addPropertyValue("tokenValiditySeconds", tokenValidity); } services.setSource(source); - services.getPropertyValues().addPropertyValue(ATT_KEY, key); + services.getPropertyValues().addPropertyValue("key", key); servicesName = pc.getReaderContext().registerWithGeneratedName(services); pc.registerBeanComponent(new BeanComponentDefinition(services, servicesName)); } else { diff --git a/config/src/test/java/org/springframework/security/config/http/HttpSecurityBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/http/HttpSecurityBeanDefinitionParserTests.java index 463f0a8cb0..3b2273944b 100644 --- a/config/src/test/java/org/springframework/security/config/http/HttpSecurityBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/http/HttpSecurityBeanDefinitionParserTests.java @@ -1,9 +1,19 @@ package org.springframework.security.config.http; -import static org.hamcrest.Matchers.*; -import static org.junit.Assert.*; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.not; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import static org.springframework.security.config.ConfigTestUtils.AUTH_PROVIDER_XML; -import static org.springframework.security.config.http.AuthenticationConfigBuilder.*; +import static org.springframework.security.config.http.AuthenticationConfigBuilder.AUTHENTICATION_PROCESSING_FILTER_CLASS; +import static org.springframework.security.config.http.AuthenticationConfigBuilder.OPEN_ID_AUTHENTICATION_PROCESSING_FILTER_CLASS; +import static org.springframework.security.config.http.AuthenticationConfigBuilder.OPEN_ID_AUTHENTICATION_PROVIDER_CLASS; import java.lang.reflect.Method; import java.util.ArrayList; @@ -26,6 +36,8 @@ import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.ConfigAttribute; import org.springframework.security.access.SecurityConfig; +import org.springframework.security.authentication.ProviderManager; +import org.springframework.security.authentication.RememberMeAuthenticationProvider; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.config.BeanIds; @@ -579,6 +591,11 @@ public class HttpSecurityBeanDefinitionParserTests { List logoutHandlers = (List) FieldUtils.getFieldValue(getFilter(LogoutFilter.class), "handlers"); assertEquals(2, logoutHandlers.size()); assertEquals(getRememberMeServices(), logoutHandlers.get(1)); + // SEC-1281 + Map ams = appContext.getBeansOfType(ProviderManager.class); + ams.remove(BeanIds.AUTHENTICATION_MANAGER); + RememberMeAuthenticationProvider rmp = (RememberMeAuthenticationProvider) ((ProviderManager)ams.values().toArray()[0]).getProviders().get(1); + assertEquals("ourkey", rmp.getKey()); } @Test