From d3a9cc6eae56fb9826d86f021e71faa482675abe Mon Sep 17 00:00:00 2001 From: Rob Winch Date: Tue, 12 Apr 2016 16:26:53 -0500 Subject: [PATCH] Add CsrfTokenRepository (#3805) * Create LazyCsrfTokenRepository Fixes gh-3790 * Add CookieCsrfTokenRepository Fixes gh-3009 --- .../web/configurers/CsrfConfigurer.java | 53 +-- .../config/http/CsrfBeanDefinitionParser.java | 33 +- .../configurers/CsrfConfigurerTests.groovy | 393 ++++++++++-------- .../config/http/CsrfConfigTests.groovy | 347 ++++++++-------- docs/manual/src/docs/asciidoc/index.adoc | 42 +- .../SecurityMockMvcRequestPostProcessors.java | 183 ++++---- .../web/csrf/CookieCsrfTokenRepository.java | 126 ++++++ .../web/csrf/CsrfAuthenticationStrategy.java | 93 +---- .../security/web/csrf/CsrfFilter.java | 126 ++---- .../web/csrf/LazyCsrfTokenRepository.java | 186 +++++++++ .../csrf/CookieCsrfTokenRepositoryTests.java | 189 +++++++++ .../csrf/CsrfAuthenticationStrategyTests.java | 90 ++-- .../security/web/csrf/CsrfFilterTests.java | 355 +++++++++------- .../csrf/LazyCsrfTokenRepositoryTests.java | 102 +++++ 14 files changed, 1461 insertions(+), 857 deletions(-) create mode 100644 web/src/main/java/org/springframework/security/web/csrf/CookieCsrfTokenRepository.java create mode 100644 web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java create mode 100644 web/src/test/java/org/springframework/security/web/csrf/CookieCsrfTokenRepositoryTests.java create mode 100644 web/src/test/java/org/springframework/security/web/csrf/LazyCsrfTokenRepositoryTests.java diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java index e4cf68dd49..b0a146c10d 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java @@ -33,6 +33,7 @@ import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfLogoutHandler; import org.springframework.security.web.csrf.CsrfTokenRepository; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; +import org.springframework.security.web.csrf.LazyCsrfTokenRepository; import org.springframework.security.web.csrf.MissingCsrfTokenException; import org.springframework.security.web.session.InvalidSessionAccessDeniedHandler; import org.springframework.security.web.session.InvalidSessionStrategy; @@ -43,8 +44,9 @@ import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; /** - * Adds CSRF protection for the methods as specified by + * Adds + * CSRF + * protection for the methods as specified by * {@link #requireCsrfProtectionMatcher(RequestMatcher)}. * *

Security Filters

@@ -62,18 +64,18 @@ import org.springframework.util.Assert; *

Shared Objects Used

* * * * @author Rob Winch * @since 3.2 */ -public final class CsrfConfigurer> extends - AbstractHttpConfigurer, H> { - private CsrfTokenRepository csrfTokenRepository = new HttpSessionCsrfTokenRepository(); +public final class CsrfConfigurer> + extends AbstractHttpConfigurer, H> { + private CsrfTokenRepository csrfTokenRepository = new LazyCsrfTokenRepository( + new HttpSessionCsrfTokenRepository()); private RequestMatcher requireCsrfProtectionMatcher = CsrfFilter.DEFAULT_CSRF_MATCHER; private List ignoredCsrfProtectionMatchers = new ArrayList(); @@ -86,12 +88,13 @@ public final class CsrfConfigurer> extends /** * Specify the {@link CsrfTokenRepository} to use. The default is an - * {@link HttpSessionCsrfTokenRepository}. + * {@link HttpSessionCsrfTokenRepository} wrapped by {@link LazyCsrfTokenRepository}. * * @param csrfTokenRepository the {@link CsrfTokenRepository} to use * @return the {@link CsrfConfigurer} for further customizations */ - public CsrfConfigurer csrfTokenRepository(CsrfTokenRepository csrfTokenRepository) { + public CsrfConfigurer csrfTokenRepository( + CsrfTokenRepository csrfTokenRepository) { Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null"); this.csrfTokenRepository = csrfTokenRepository; return this; @@ -144,7 +147,7 @@ public final class CsrfConfigurer> extends @SuppressWarnings("unchecked") @Override public void configure(H http) throws Exception { - CsrfFilter filter = new CsrfFilter(csrfTokenRepository); + CsrfFilter filter = new CsrfFilter(this.csrfTokenRepository); RequestMatcher requireCsrfProtectionMatcher = getRequireCsrfProtectionMatcher(); if (requireCsrfProtectionMatcher != null) { filter.setRequireCsrfProtectionMatcher(requireCsrfProtectionMatcher); @@ -155,14 +158,14 @@ public final class CsrfConfigurer> extends } LogoutConfigurer logoutConfigurer = http.getConfigurer(LogoutConfigurer.class); if (logoutConfigurer != null) { - logoutConfigurer.addLogoutHandler(new CsrfLogoutHandler(csrfTokenRepository)); + logoutConfigurer + .addLogoutHandler(new CsrfLogoutHandler(this.csrfTokenRepository)); } SessionManagementConfigurer sessionConfigurer = http .getConfigurer(SessionManagementConfigurer.class); if (sessionConfigurer != null) { - sessionConfigurer - .addSessionAuthenticationStrategy(new CsrfAuthenticationStrategy( - csrfTokenRepository)); + sessionConfigurer.addSessionAuthenticationStrategy( + new CsrfAuthenticationStrategy(this.csrfTokenRepository)); } filter = postProcess(filter); http.addFilter(filter); @@ -175,12 +178,12 @@ public final class CsrfConfigurer> extends * @return the {@link RequestMatcher} to use */ private RequestMatcher getRequireCsrfProtectionMatcher() { - if (ignoredCsrfProtectionMatchers.isEmpty()) { - return requireCsrfProtectionMatcher; + if (this.ignoredCsrfProtectionMatchers.isEmpty()) { + return this.requireCsrfProtectionMatcher; } - return new AndRequestMatcher(requireCsrfProtectionMatcher, - new NegatedRequestMatcher(new OrRequestMatcher( - ignoredCsrfProtectionMatchers))); + return new AndRequestMatcher(this.requireCsrfProtectionMatcher, + new NegatedRequestMatcher( + new OrRequestMatcher(this.ignoredCsrfProtectionMatchers))); } /** @@ -238,7 +241,8 @@ public final class CsrfConfigurer> extends */ private AccessDeniedHandler createAccessDeniedHandler(H http) { InvalidSessionStrategy invalidSessionStrategy = getInvalidSessionStrategy(http); - AccessDeniedHandler defaultAccessDeniedHandler = getDefaultAccessDeniedHandler(http); + AccessDeniedHandler defaultAccessDeniedHandler = getDefaultAccessDeniedHandler( + http); if (invalidSessionStrategy == null) { return defaultAccessDeniedHandler; } @@ -258,16 +262,17 @@ public final class CsrfConfigurer> extends * @author Rob Winch * @since 4.0 */ - private class IgnoreCsrfProtectionRegistry extends - AbstractRequestMatcherRegistry { + private class IgnoreCsrfProtectionRegistry + extends AbstractRequestMatcherRegistry { public CsrfConfigurer and() { return CsrfConfigurer.this; } + @Override protected IgnoreCsrfProtectionRegistry chainRequestMatchers( List requestMatchers) { - ignoredCsrfProtectionMatchers.addAll(requestMatchers); + CsrfConfigurer.this.ignoredCsrfProtectionMatchers.addAll(requestMatchers); return this; } } diff --git a/config/src/main/java/org/springframework/security/config/http/CsrfBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/http/CsrfBeanDefinitionParser.java index 19dadf5605..ed2095e258 100644 --- a/config/src/main/java/org/springframework/security/config/http/CsrfBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/http/CsrfBeanDefinitionParser.java @@ -15,6 +15,8 @@ */ package org.springframework.security.config.http; +import org.w3c.dom.Element; + import org.springframework.beans.BeanMetadataElement; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.parsing.BeanComponentDefinition; @@ -31,13 +33,13 @@ import org.springframework.security.web.csrf.CsrfAuthenticationStrategy; import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfLogoutHandler; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; +import org.springframework.security.web.csrf.LazyCsrfTokenRepository; import org.springframework.security.web.csrf.MissingCsrfTokenException; import org.springframework.security.web.servlet.support.csrf.CsrfRequestDataValueProcessor; import org.springframework.security.web.session.InvalidSessionAccessDeniedHandler; import org.springframework.security.web.session.InvalidSessionStrategy; import org.springframework.util.ClassUtils; import org.springframework.util.StringUtils; -import org.w3c.dom.Element; /** * Parser for the {@code CsrfFilter}. @@ -55,6 +57,7 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser { private String csrfRepositoryRef; private BeanDefinition csrfFilter; + @Override public BeanDefinition parse(Element element, ParserContext pc) { boolean disabled = element != null && "true".equals(element.getAttribute("disabled")); @@ -73,29 +76,33 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser { String matcherRef = null; if (element != null) { - csrfRepositoryRef = element.getAttribute(ATT_REPOSITORY); + this.csrfRepositoryRef = element.getAttribute(ATT_REPOSITORY); matcherRef = element.getAttribute(ATT_MATCHER); } - if (!StringUtils.hasText(csrfRepositoryRef)) { + if (!StringUtils.hasText(this.csrfRepositoryRef)) { + RootBeanDefinition csrfTokenRepository = new RootBeanDefinition( HttpSessionCsrfTokenRepository.class); - csrfRepositoryRef = pc.getReaderContext().generateBeanName( - csrfTokenRepository); - pc.registerBeanComponent(new BeanComponentDefinition(csrfTokenRepository, - csrfRepositoryRef)); + BeanDefinitionBuilder lazyTokenRepository = BeanDefinitionBuilder + .rootBeanDefinition(LazyCsrfTokenRepository.class); + lazyTokenRepository.addConstructorArgValue(csrfTokenRepository); + this.csrfRepositoryRef = pc.getReaderContext() + .generateBeanName(lazyTokenRepository.getBeanDefinition()); + pc.registerBeanComponent(new BeanComponentDefinition( + lazyTokenRepository.getBeanDefinition(), this.csrfRepositoryRef)); } BeanDefinitionBuilder builder = BeanDefinitionBuilder .rootBeanDefinition(CsrfFilter.class); - builder.addConstructorArgReference(csrfRepositoryRef); + builder.addConstructorArgReference(this.csrfRepositoryRef); if (StringUtils.hasText(matcherRef)) { builder.addPropertyReference("requireCsrfProtectionMatcher", matcherRef); } - csrfFilter = builder.getBeanDefinition(); - return csrfFilter; + this.csrfFilter = builder.getBeanDefinition(); + return this.csrfFilter; } /** @@ -108,7 +115,7 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser { BeanMetadataElement defaultDeniedHandler) { BeanMetadataElement accessDeniedHandler = createAccessDeniedHandler( invalidSessionStrategy, defaultDeniedHandler); - csrfFilter.getPropertyValues().addPropertyValue("accessDeniedHandler", + this.csrfFilter.getPropertyValues().addPropertyValue("accessDeniedHandler", accessDeniedHandler); } @@ -152,14 +159,14 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser { BeanDefinition getCsrfAuthenticationStrategy() { BeanDefinitionBuilder csrfAuthenticationStrategy = BeanDefinitionBuilder .rootBeanDefinition(CsrfAuthenticationStrategy.class); - csrfAuthenticationStrategy.addConstructorArgReference(csrfRepositoryRef); + csrfAuthenticationStrategy.addConstructorArgReference(this.csrfRepositoryRef); return csrfAuthenticationStrategy.getBeanDefinition(); } BeanDefinition getCsrfLogoutHandler() { BeanDefinitionBuilder csrfAuthenticationStrategy = BeanDefinitionBuilder .rootBeanDefinition(CsrfLogoutHandler.class); - csrfAuthenticationStrategy.addConstructorArgReference(csrfRepositoryRef); + csrfAuthenticationStrategy.addConstructorArgReference(this.csrfRepositoryRef); return csrfAuthenticationStrategy.getBeanDefinition(); } } diff --git a/config/src/test/groovy/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.groovy b/config/src/test/groovy/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.groovy index a2366e281e..12e8448c0a 100644 --- a/config/src/test/groovy/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.groovy +++ b/config/src/test/groovy/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.groovy @@ -15,11 +15,10 @@ */ package org.springframework.security.config.annotation.web.configurers -import org.springframework.security.web.util.matcher.AntPathRequestMatcher - import javax.servlet.http.HttpServletResponse -import org.springframework.context.annotation.Configuration +import spock.lang.Unroll + import org.springframework.mock.web.MockHttpServletRequest import org.springframework.mock.web.MockHttpServletResponse import org.springframework.security.config.annotation.BaseSpringSpec @@ -27,15 +26,13 @@ import org.springframework.security.config.annotation.authentication.builders.Au import org.springframework.security.config.annotation.web.builders.HttpSecurity import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter -import org.springframework.security.config.annotation.web.servlet.configuration.EnableWebMvcSecurity; import org.springframework.security.web.access.AccessDeniedHandler import org.springframework.security.web.csrf.CsrfFilter import org.springframework.security.web.csrf.CsrfTokenRepository +import org.springframework.security.web.util.matcher.AntPathRequestMatcher import org.springframework.security.web.util.matcher.RequestMatcher import org.springframework.web.servlet.support.RequestDataValueProcessor -import spock.lang.Unroll - /** * * @author Rob Winch @@ -45,31 +42,31 @@ class CsrfConfigurerTests extends BaseSpringSpec { @Unroll def "csrf applied by default"() { setup: - loadConfig(CsrfAppliedDefaultConfig) - request.method = httpMethod - clearCsrfToken() + loadConfig(CsrfAppliedDefaultConfig) + request.method = httpMethod + clearCsrfToken() when: - springSecurityFilterChain.doFilter(request,response,chain) + springSecurityFilterChain.doFilter(request,response,chain) then: - response.status == httpStatus + response.status == httpStatus where: - httpMethod | httpStatus - 'POST' | HttpServletResponse.SC_FORBIDDEN - 'PUT' | HttpServletResponse.SC_FORBIDDEN - 'PATCH' | HttpServletResponse.SC_FORBIDDEN - 'DELETE' | HttpServletResponse.SC_FORBIDDEN - 'INVALID' | HttpServletResponse.SC_FORBIDDEN - 'GET' | HttpServletResponse.SC_OK - 'HEAD' | HttpServletResponse.SC_OK - 'TRACE' | HttpServletResponse.SC_OK - 'OPTIONS' | HttpServletResponse.SC_OK + httpMethod | httpStatus + 'POST' | HttpServletResponse.SC_FORBIDDEN + 'PUT' | HttpServletResponse.SC_FORBIDDEN + 'PATCH' | HttpServletResponse.SC_FORBIDDEN + 'DELETE' | HttpServletResponse.SC_FORBIDDEN + 'INVALID' | HttpServletResponse.SC_FORBIDDEN + 'GET' | HttpServletResponse.SC_OK + 'HEAD' | HttpServletResponse.SC_OK + 'TRACE' | HttpServletResponse.SC_OK + 'OPTIONS' | HttpServletResponse.SC_OK } def "csrf default creates CsrfRequestDataValueProcessor"() { when: - loadConfig(CsrfAppliedDefaultConfig) + loadConfig(CsrfAppliedDefaultConfig) then: - context.getBean(RequestDataValueProcessor) + context.getBean(RequestDataValueProcessor) } @EnableWebSecurity @@ -82,14 +79,14 @@ class CsrfConfigurerTests extends BaseSpringSpec { def "csrf disable"() { setup: - loadConfig(DisableCsrfConfig) - request.method = "POST" - clearCsrfToken() + loadConfig(DisableCsrfConfig) + request.method = "POST" + clearCsrfToken() when: - springSecurityFilterChain.doFilter(request,response,chain) + springSecurityFilterChain.doFilter(request,response,chain) then: - !findFilter(CsrfFilter) - response.status == HttpServletResponse.SC_OK + !findFilter(CsrfFilter) + response.status == HttpServletResponse.SC_OK } @EnableWebSecurity @@ -98,29 +95,29 @@ class CsrfConfigurerTests extends BaseSpringSpec { @Override protected void configure(HttpSecurity http) throws Exception { http - .csrf().disable() + .csrf().disable() } } def "SEC-2498: Disable CSRF enables RequestCache for any method"() { setup: - loadConfig(DisableCsrfEnablesRequestCacheConfig) - request.requestURI = '/tosave' - request.method = "POST" - clearCsrfToken() + loadConfig(DisableCsrfEnablesRequestCacheConfig) + request.requestURI = '/tosave' + request.method = "POST" + clearCsrfToken() when: - springSecurityFilterChain.doFilter(request,response,chain) + springSecurityFilterChain.doFilter(request,response,chain) then: - response.redirectedUrl + response.redirectedUrl when: - super.setupWeb(request.session) - request.method = "POST" - request.servletPath = '/login' - request.parameters['username'] = ['user'] as String[] - request.parameters['password'] = ['password'] as String[] - springSecurityFilterChain.doFilter(request,response,chain) + super.setupWeb(request.session) + request.method = "POST" + request.servletPath = '/login' + request.parameters['username'] = ['user'] as String[] + request.parameters['password'] = ['password'] as String[] + springSecurityFilterChain.doFilter(request,response,chain) then: - response.redirectedUrl == 'http://localhost/tosave' + response.redirectedUrl == 'http://localhost/tosave' } @EnableWebSecurity @@ -129,38 +126,37 @@ class CsrfConfigurerTests extends BaseSpringSpec { @Override protected void configure(HttpSecurity http) throws Exception { http - .authorizeRequests() + .authorizeRequests() .anyRequest().authenticated() .and() - .formLogin().and() - .csrf().disable() - + .formLogin().and() + .csrf().disable() } @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { auth - .inMemoryAuthentication() + .inMemoryAuthentication() .withUser("user").password("password").roles("USER") } } def "SEC-2422: csrf expire CSRF token and session-management invalid-session-url"() { setup: - loadConfig(InvalidSessionUrlConfig) - request.session.clearAttributes() - request.setParameter("_csrf","abc") - request.method = "POST" + loadConfig(InvalidSessionUrlConfig) + request.session.clearAttributes() + request.setParameter("_csrf","abc") + request.method = "POST" when: "No existing expected CsrfToken (session times out) and a POST" - springSecurityFilterChain.doFilter(request,response,chain) + springSecurityFilterChain.doFilter(request,response,chain) then: "sent to the session timeout page page" - response.status == HttpServletResponse.SC_MOVED_TEMPORARILY - response.redirectedUrl == "/error/sessionError" + response.status == HttpServletResponse.SC_MOVED_TEMPORARILY + response.redirectedUrl == "/error/sessionError" when: "Existing expected CsrfToken and a POST (invalid token provided)" - response = new MockHttpServletResponse() - request = new MockHttpServletRequest(session: request.session, method:'POST') - springSecurityFilterChain.doFilter(request,response,chain) + response = new MockHttpServletResponse() + request = new MockHttpServletRequest(session: request.session, method:'POST') + springSecurityFilterChain.doFilter(request,response,chain) then: "Access Denied occurs" - response.status == HttpServletResponse.SC_FORBIDDEN + response.status == HttpServletResponse.SC_FORBIDDEN } @EnableWebSecurity @@ -168,26 +164,26 @@ class CsrfConfigurerTests extends BaseSpringSpec { @Override protected void configure(HttpSecurity http) throws Exception { http - .csrf().and() - .sessionManagement() + .csrf().and() + .sessionManagement() .invalidSessionUrl("/error/sessionError") } } def "csrf requireCsrfProtectionMatcher"() { setup: - RequireCsrfProtectionMatcherConfig.matcher = Mock(RequestMatcher) - RequireCsrfProtectionMatcherConfig.matcher.matches(_) >>> [false,true] - loadConfig(RequireCsrfProtectionMatcherConfig) - clearCsrfToken() + RequireCsrfProtectionMatcherConfig.matcher = Mock(RequestMatcher) + RequireCsrfProtectionMatcherConfig.matcher.matches(_) >>> [false, true] + loadConfig(RequireCsrfProtectionMatcherConfig) + clearCsrfToken() when: - springSecurityFilterChain.doFilter(request,response,chain) + springSecurityFilterChain.doFilter(request,response,chain) then: - response.status == HttpServletResponse.SC_OK + response.status == HttpServletResponse.SC_OK when: - springSecurityFilterChain.doFilter(request,response,chain) + springSecurityFilterChain.doFilter(request,response,chain) then: - response.status == HttpServletResponse.SC_FORBIDDEN + response.status == HttpServletResponse.SC_FORBIDDEN } @EnableWebSecurity @@ -197,53 +193,53 @@ class CsrfConfigurerTests extends BaseSpringSpec { @Override protected void configure(HttpSecurity http) throws Exception { http - .csrf() + .csrf() .requireCsrfProtectionMatcher(matcher) } } def "csrf csrfTokenRepository"() { setup: - CsrfTokenRepositoryConfig.repo = Mock(CsrfTokenRepository) - loadConfig(CsrfTokenRepositoryConfig) - clearCsrfToken() + CsrfTokenRepositoryConfig.repo = Mock(CsrfTokenRepository) + loadConfig(CsrfTokenRepositoryConfig) + clearCsrfToken() when: - springSecurityFilterChain.doFilter(request,response,chain) + springSecurityFilterChain.doFilter(request,response,chain) then: - 1 * CsrfTokenRepositoryConfig.repo.loadToken(_) >> csrfToken - response.status == HttpServletResponse.SC_OK + 1 * CsrfTokenRepositoryConfig.repo.loadToken(_) >> csrfToken + response.status == HttpServletResponse.SC_OK } def "csrf clears on logout"() { setup: - CsrfTokenRepositoryConfig.repo = Mock(CsrfTokenRepository) - 1 * CsrfTokenRepositoryConfig.repo.loadToken(_) >> csrfToken - loadConfig(CsrfTokenRepositoryConfig) - login() - request.method = "POST" - request.servletPath = "/logout" + CsrfTokenRepositoryConfig.repo = Mock(CsrfTokenRepository) + 1 * CsrfTokenRepositoryConfig.repo.loadToken(_) >> csrfToken + loadConfig(CsrfTokenRepositoryConfig) + login() + request.method = "POST" + request.servletPath = "/logout" when: - springSecurityFilterChain.doFilter(request,response,chain) + springSecurityFilterChain.doFilter(request,response,chain) then: - 1 * CsrfTokenRepositoryConfig.repo.saveToken(null, _, _) + 1 * CsrfTokenRepositoryConfig.repo.saveToken(null, _, _) } def "csrf clears on login"() { setup: - CsrfTokenRepositoryConfig.repo = Mock(CsrfTokenRepository) - (1.._) * CsrfTokenRepositoryConfig.repo.loadToken(_) >> csrfToken - (1.._) * CsrfTokenRepositoryConfig.repo.generateToken(_) >> csrfToken - loadConfig(CsrfTokenRepositoryConfig) - request.method = "POST" - request.getSession() - request.servletPath = "/login" - request.setParameter("username", "user") - request.setParameter("password", "password") + CsrfTokenRepositoryConfig.repo = Mock(CsrfTokenRepository) + (1.._) * CsrfTokenRepositoryConfig.repo.loadToken(_) >> csrfToken + (1.._) * CsrfTokenRepositoryConfig.repo.generateToken(_) >> csrfToken + loadConfig(CsrfTokenRepositoryConfig) + request.method = "POST" + request.getSession() + request.servletPath = "/login" + request.setParameter("username", "user") + request.setParameter("password", "password") when: - springSecurityFilterChain.doFilter(request,response,chain) + springSecurityFilterChain.doFilter(request,response,chain) then: - response.redirectedUrl == "/" - (1.._) * CsrfTokenRepositoryConfig.repo.saveToken(null, _, _) + response.redirectedUrl == "/" + (1.._) * CsrfTokenRepositoryConfig.repo.saveToken(null, _, _) } @EnableWebSecurity @@ -253,30 +249,30 @@ class CsrfConfigurerTests extends BaseSpringSpec { @Override protected void configure(HttpSecurity http) throws Exception { http - .formLogin() + .formLogin() .and() - .csrf() + .csrf() .csrfTokenRepository(repo) } @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { auth - .inMemoryAuthentication() + .inMemoryAuthentication() .withUser("user").password("password").roles("USER") } } def "csrf access denied handler"() { setup: - AccessDeniedHandlerConfig.deniedHandler = Mock(AccessDeniedHandler) - 1 * AccessDeniedHandlerConfig.deniedHandler.handle(_, _, _) - loadConfig(AccessDeniedHandlerConfig) - clearCsrfToken() - request.method = "POST" + AccessDeniedHandlerConfig.deniedHandler = Mock(AccessDeniedHandler) + 1 * AccessDeniedHandlerConfig.deniedHandler.handle(_, _, _) + loadConfig(AccessDeniedHandlerConfig) + clearCsrfToken() + request.method = "POST" when: - springSecurityFilterChain.doFilter(request,response,chain) + springSecurityFilterChain.doFilter(request,response,chain) then: - response.status == HttpServletResponse.SC_OK + response.status == HttpServletResponse.SC_OK } @EnableWebSecurity @@ -286,24 +282,24 @@ class CsrfConfigurerTests extends BaseSpringSpec { @Override protected void configure(HttpSecurity http) throws Exception { http - .exceptionHandling() + .exceptionHandling() .accessDeniedHandler(deniedHandler) } } def "formLogin requires CSRF token"() { setup: - loadConfig(FormLoginConfig) - clearCsrfToken() - request.setParameter("username", "user") - request.setParameter("password", "password") - request.servletPath = "/login" - request.method = "POST" + loadConfig(FormLoginConfig) + clearCsrfToken() + request.setParameter("username", "user") + request.setParameter("password", "password") + request.servletPath = "/login" + request.method = "POST" when: - springSecurityFilterChain.doFilter(request,response,chain) + springSecurityFilterChain.doFilter(request,response,chain) then: - response.status == HttpServletResponse.SC_FORBIDDEN - currentAuthentication == null + response.status == HttpServletResponse.SC_FORBIDDEN + currentAuthentication == null } @EnableWebSecurity @@ -313,34 +309,34 @@ class CsrfConfigurerTests extends BaseSpringSpec { @Override protected void configure(HttpSecurity http) throws Exception { http - .formLogin() + .formLogin() } } def "logout requires CSRF token"() { setup: - loadConfig(LogoutConfig) - clearCsrfToken() - login() - request.servletPath = "/logout" - request.method = "POST" + loadConfig(LogoutConfig) + clearCsrfToken() + login() + request.servletPath = "/logout" + request.method = "POST" when: - springSecurityFilterChain.doFilter(request,response,chain) + springSecurityFilterChain.doFilter(request,response,chain) then: "logout is not allowed and user is still authenticated" - response.status == HttpServletResponse.SC_FORBIDDEN - currentAuthentication != null + response.status == HttpServletResponse.SC_FORBIDDEN + currentAuthentication != null } def "SEC-2543: CSRF means logout requires POST"() { setup: - loadConfig(LogoutConfig) - login() - request.servletPath = "/logout" - request.method = "GET" + loadConfig(LogoutConfig) + login() + request.servletPath = "/logout" + request.method = "GET" when: - springSecurityFilterChain.doFilter(request,response,chain) + springSecurityFilterChain.doFilter(request,response,chain) then: "logout with GET is not performed" - currentAuthentication != null + currentAuthentication != null } @EnableWebSecurity @@ -350,20 +346,20 @@ class CsrfConfigurerTests extends BaseSpringSpec { @Override protected void configure(HttpSecurity http) throws Exception { http - .formLogin() + .formLogin() } } def "CSRF can explicitly enable GET for logout"() { setup: - loadConfig(LogoutAllowsGetConfig) - login() - request.servletPath = "/logout" - request.method = "GET" + loadConfig(LogoutAllowsGetConfig) + login() + request.servletPath = "/logout" + request.method = "GET" when: - springSecurityFilterChain.doFilter(request,response,chain) + springSecurityFilterChain.doFilter(request,response,chain) then: "logout with GET is not performed" - currentAuthentication == null + currentAuthentication == null } @EnableWebSecurity @@ -373,64 +369,64 @@ class CsrfConfigurerTests extends BaseSpringSpec { @Override protected void configure(HttpSecurity http) throws Exception { http - .formLogin().and() - .logout() + .formLogin().and() + .logout() .logoutRequestMatcher(new AntPathRequestMatcher("/logout")) } } def "csrf disables POST requests from RequestCache"() { setup: - CsrfDisablesPostRequestFromRequestCacheConfig.repo = Mock(CsrfTokenRepository) - (1.._) * CsrfDisablesPostRequestFromRequestCacheConfig.repo.generateToken(_) >> csrfToken - loadConfig(CsrfDisablesPostRequestFromRequestCacheConfig) - request.servletPath = "/some-url" - request.requestURI = "/some-url" - request.method = "POST" + CsrfDisablesPostRequestFromRequestCacheConfig.repo = Mock(CsrfTokenRepository) + (1.._) * CsrfDisablesPostRequestFromRequestCacheConfig.repo.generateToken(_) >> csrfToken + loadConfig(CsrfDisablesPostRequestFromRequestCacheConfig) + request.servletPath = "/some-url" + request.requestURI = "/some-url" + request.method = "POST" when: "CSRF passes and our session times out" - springSecurityFilterChain.doFilter(request,response,chain) + springSecurityFilterChain.doFilter(request,response,chain) then: "sent to the login page" - (1.._) * CsrfDisablesPostRequestFromRequestCacheConfig.repo.loadToken(_) >> csrfToken - response.status == HttpServletResponse.SC_MOVED_TEMPORARILY - response.redirectedUrl == "http://localhost/login" + (1.._) * CsrfDisablesPostRequestFromRequestCacheConfig.repo.loadToken(_) >> csrfToken + response.status == HttpServletResponse.SC_MOVED_TEMPORARILY + response.redirectedUrl == "http://localhost/login" when: "authenticate successfully" - super.setupWeb(request.session) - request.servletPath = "/login" - request.setParameter("username","user") - request.setParameter("password","password") - request.method = "POST" - springSecurityFilterChain.doFilter(request,response,chain) + super.setupWeb(request.session) + request.servletPath = "/login" + request.setParameter("username","user") + request.setParameter("password","password") + request.method = "POST" + springSecurityFilterChain.doFilter(request,response,chain) then: "sent to default success because we don't want csrf attempts made prior to authentication to pass" - (1.._) * CsrfDisablesPostRequestFromRequestCacheConfig.repo.loadToken(_) >> csrfToken - response.status == HttpServletResponse.SC_MOVED_TEMPORARILY - response.redirectedUrl == "/" + (1.._) * CsrfDisablesPostRequestFromRequestCacheConfig.repo.loadToken(_) >> csrfToken + response.status == HttpServletResponse.SC_MOVED_TEMPORARILY + response.redirectedUrl == "/" } def "csrf enables GET requests with RequestCache"() { setup: - CsrfDisablesPostRequestFromRequestCacheConfig.repo = Mock(CsrfTokenRepository) - (1.._) * CsrfDisablesPostRequestFromRequestCacheConfig.repo.generateToken(_) >> csrfToken - loadConfig(CsrfDisablesPostRequestFromRequestCacheConfig) - request.servletPath = "/some-url" - request.requestURI = "/some-url" - request.method = "GET" + CsrfDisablesPostRequestFromRequestCacheConfig.repo = Mock(CsrfTokenRepository) + (1.._) * CsrfDisablesPostRequestFromRequestCacheConfig.repo.generateToken(_) >> csrfToken + loadConfig(CsrfDisablesPostRequestFromRequestCacheConfig) + request.servletPath = "/some-url" + request.requestURI = "/some-url" + request.method = "GET" when: "CSRF passes and our session times out" - springSecurityFilterChain.doFilter(request,response,chain) + springSecurityFilterChain.doFilter(request,response,chain) then: "sent to the login page" - (1.._) * CsrfDisablesPostRequestFromRequestCacheConfig.repo.loadToken(_) >> csrfToken - response.status == HttpServletResponse.SC_MOVED_TEMPORARILY - response.redirectedUrl == "http://localhost/login" + (1.._) * CsrfDisablesPostRequestFromRequestCacheConfig.repo.loadToken(_) >> csrfToken + response.status == HttpServletResponse.SC_MOVED_TEMPORARILY + response.redirectedUrl == "http://localhost/login" when: "authenticate successfully" - super.setupWeb(request.session) - request.servletPath = "/login" - request.setParameter("username","user") - request.setParameter("password","password") - request.method = "POST" - springSecurityFilterChain.doFilter(request,response,chain) + super.setupWeb(request.session) + request.servletPath = "/login" + request.setParameter("username","user") + request.setParameter("password","password") + request.method = "POST" + springSecurityFilterChain.doFilter(request,response,chain) then: "sent to original URL since it was a GET" - (1.._) * CsrfDisablesPostRequestFromRequestCacheConfig.repo.loadToken(_) >> csrfToken - response.status == HttpServletResponse.SC_MOVED_TEMPORARILY - response.redirectedUrl == "http://localhost/some-url" + (1.._) * CsrfDisablesPostRequestFromRequestCacheConfig.repo.loadToken(_) >> csrfToken + response.status == HttpServletResponse.SC_MOVED_TEMPORARILY + response.redirectedUrl == "http://localhost/some-url" } @EnableWebSecurity @@ -440,18 +436,18 @@ class CsrfConfigurerTests extends BaseSpringSpec { @Override protected void configure(HttpSecurity http) throws Exception { http - .authorizeRequests() + .authorizeRequests() .anyRequest().authenticated() .and() - .formLogin() + .formLogin() .and() - .csrf() + .csrf() .csrfTokenRepository(repo) } @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { auth - .inMemoryAuthentication() + .inMemoryAuthentication() .withUser("user").password("password").roles("USER") } } @@ -463,6 +459,39 @@ class CsrfConfigurerTests extends BaseSpringSpec { thrown(IllegalArgumentException) } + def 'default does not create session'() { + setup: + request = new MockHttpServletRequest(method:'GET') + loadConfig(DefaultDoesNotCreateSession) + when: + springSecurityFilterChain.doFilter(request,response,chain) + then: + request.getSession(false) == null + } + + @EnableWebSecurity(debug=true) + static class DefaultDoesNotCreateSession extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().permitAll() + .and() + .formLogin().and() + .httpBasic(); + // @formatter:on + } + + @Override + protected void configure(AuthenticationManagerBuilder auth) throws Exception { + auth + .inMemoryAuthentication() + .withUser("user").password("password").roles("USER") + } + } + def clearCsrfToken() { request.removeAllParameters() } diff --git a/config/src/test/groovy/org/springframework/security/config/http/CsrfConfigTests.groovy b/config/src/test/groovy/org/springframework/security/config/http/CsrfConfigTests.groovy index d83fb8ac57..9d8277018c 100644 --- a/config/src/test/groovy/org/springframework/security/config/http/CsrfConfigTests.groovy +++ b/config/src/test/groovy/org/springframework/security/config/http/CsrfConfigTests.groovy @@ -12,12 +12,11 @@ */ package org.springframework.security.config.http -import static org.mockito.Matchers.* -import static org.mockito.Mockito.* - import javax.servlet.http.HttpServletRequest import javax.servlet.http.HttpServletResponse +import spock.lang.Unroll + import org.springframework.mock.web.MockFilterChain import org.springframework.mock.web.MockHttpServletRequest import org.springframework.mock.web.MockHttpServletResponse @@ -36,7 +35,8 @@ import org.springframework.security.web.csrf.DefaultCsrfToken import org.springframework.security.web.util.matcher.RequestMatcher import org.springframework.web.servlet.support.RequestDataValueProcessor -import spock.lang.Unroll +import static org.mockito.Matchers.* +import static org.mockito.Mockito.* /** * @@ -73,256 +73,253 @@ class CsrfConfigTests extends AbstractHttpConfigTests { def 'csrf disabled'() { when: - httpAutoConfig { - csrf(disabled:true) - } - createAppContext() + httpAutoConfig { csrf(disabled:true) } + createAppContext() then: - !getFilter(CsrfFilter) + !getFilter(CsrfFilter) } @Unroll def 'csrf defaults'() { setup: - httpAutoConfig { - 'csrf'() - } - createAppContext() + httpAutoConfig { 'csrf'() } + createAppContext() when: - request.method = httpMethod - springSecurityFilterChain.doFilter(request,response,chain) + request.method = httpMethod + springSecurityFilterChain.doFilter(request,response,chain) then: - response.status == httpStatus + response.status == httpStatus where: - httpMethod | httpStatus - 'POST' | HttpServletResponse.SC_FORBIDDEN - 'PUT' | HttpServletResponse.SC_FORBIDDEN - 'PATCH' | HttpServletResponse.SC_FORBIDDEN - 'DELETE' | HttpServletResponse.SC_FORBIDDEN - 'INVALID' | HttpServletResponse.SC_FORBIDDEN - 'GET' | HttpServletResponse.SC_OK - 'HEAD' | HttpServletResponse.SC_OK - 'TRACE' | HttpServletResponse.SC_OK - 'OPTIONS' | HttpServletResponse.SC_OK + httpMethod | httpStatus + 'POST' | HttpServletResponse.SC_FORBIDDEN + 'PUT' | HttpServletResponse.SC_FORBIDDEN + 'PATCH' | HttpServletResponse.SC_FORBIDDEN + 'DELETE' | HttpServletResponse.SC_FORBIDDEN + 'INVALID' | HttpServletResponse.SC_FORBIDDEN + 'GET' | HttpServletResponse.SC_OK + 'HEAD' | HttpServletResponse.SC_OK + 'TRACE' | HttpServletResponse.SC_OK + 'OPTIONS' | HttpServletResponse.SC_OK } def 'csrf default creates CsrfRequestDataValueProcessor'() { when: - httpAutoConfig { - 'csrf'() - } - createAppContext() + httpAutoConfig { 'csrf'() } + createAppContext() then: - appContext.getBean("requestDataValueProcessor",RequestDataValueProcessor) + appContext.getBean("requestDataValueProcessor",RequestDataValueProcessor) } def 'csrf custom AccessDeniedHandler'() { setup: - httpAutoConfig { - 'access-denied-handler'(ref:'adh') - 'csrf'() - } - mockBean(AccessDeniedHandler,'adh') - createAppContext() - AccessDeniedHandler adh = appContext.getBean(AccessDeniedHandler) - request.method = "POST" + httpAutoConfig { + 'access-denied-handler'(ref:'adh') + 'csrf'() + } + mockBean(AccessDeniedHandler,'adh') + createAppContext() + AccessDeniedHandler adh = appContext.getBean(AccessDeniedHandler) + request.method = "POST" when: - springSecurityFilterChain.doFilter(request,response,chain) + springSecurityFilterChain.doFilter(request,response,chain) then: - verify(adh).handle(any(HttpServletRequest),any(HttpServletResponse),any(AccessDeniedException)) - response.status == HttpServletResponse.SC_OK // our mock doesn't do anything + verify(adh).handle(any(HttpServletRequest),any(HttpServletResponse),any(AccessDeniedException)) + response.status == HttpServletResponse.SC_OK // our mock doesn't do anything } def "csrf disables posts for RequestCache"() { setup: - httpAutoConfig { - 'csrf'('token-repository-ref':'repo') - 'intercept-url'(pattern:"/**",access:'ROLE_USER') - } - mockBean(CsrfTokenRepository,'repo') - createAppContext() - CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository) - CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc") - when(repo.loadToken(any(HttpServletRequest))).thenReturn(token) - when(repo.generateToken(any(HttpServletRequest))).thenReturn(token) - request.setParameter(token.parameterName,token.token) - request.servletPath = "/some-url" - request.requestURI = "/some-url" - request.method = "POST" + httpAutoConfig { + 'csrf'('token-repository-ref':'repo') + 'intercept-url'(pattern:"/**",access:'ROLE_USER') + } + mockBean(CsrfTokenRepository,'repo') + createAppContext() + CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository) + CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc") + when(repo.loadToken(any(HttpServletRequest))).thenReturn(token) + when(repo.generateToken(any(HttpServletRequest))).thenReturn(token) + request.setParameter(token.parameterName,token.token) + request.servletPath = "/some-url" + request.requestURI = "/some-url" + request.method = "POST" when: "CSRF passes and our session times out" - springSecurityFilterChain.doFilter(request,response,chain) + springSecurityFilterChain.doFilter(request,response,chain) then: "sent to the login page" - response.status == HttpServletResponse.SC_MOVED_TEMPORARILY - response.redirectedUrl == "http://localhost/login" + response.status == HttpServletResponse.SC_MOVED_TEMPORARILY + response.redirectedUrl == "http://localhost/login" when: "authenticate successfully" - response = new MockHttpServletResponse() - request = new MockHttpServletRequest(session: request.session) - request.servletPath = "/login" - request.setParameter(token.parameterName,token.token) - request.setParameter("username","user") - request.setParameter("password","password") - request.method = "POST" - springSecurityFilterChain.doFilter(request,response,chain) + response = new MockHttpServletResponse() + request = new MockHttpServletRequest(session: request.session) + request.servletPath = "/login" + request.setParameter(token.parameterName,token.token) + request.setParameter("username","user") + request.setParameter("password","password") + request.method = "POST" + springSecurityFilterChain.doFilter(request,response,chain) then: "sent to default success because we don't want csrf attempts made prior to authentication to pass" - response.status == HttpServletResponse.SC_MOVED_TEMPORARILY - response.redirectedUrl == "/" + response.status == HttpServletResponse.SC_MOVED_TEMPORARILY + response.redirectedUrl == "/" } def "csrf enables gets for RequestCache"() { setup: - httpAutoConfig { - 'csrf'('token-repository-ref':'repo') - 'intercept-url'(pattern:"/**",access:'ROLE_USER') - } - mockBean(CsrfTokenRepository,'repo') - createAppContext() - CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository) - CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc") - when(repo.loadToken(any(HttpServletRequest))).thenReturn(token) + httpAutoConfig { + 'csrf'('token-repository-ref':'repo') + 'intercept-url'(pattern:"/**",access:'ROLE_USER') + } + mockBean(CsrfTokenRepository,'repo') + createAppContext() + CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository) + CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc") + when(repo.loadToken(any(HttpServletRequest))).thenReturn(token) when(repo.generateToken(any(HttpServletRequest))).thenReturn(token) - request.setParameter(token.parameterName,token.token) - request.servletPath = "/some-url" - request.requestURI = "/some-url" - request.method = "GET" + request.setParameter(token.parameterName,token.token) + request.servletPath = "/some-url" + request.requestURI = "/some-url" + request.method = "GET" when: "CSRF passes and our session times out" - springSecurityFilterChain.doFilter(request,response,chain) + springSecurityFilterChain.doFilter(request,response,chain) then: "sent to the login page" - response.status == HttpServletResponse.SC_MOVED_TEMPORARILY - response.redirectedUrl == "http://localhost/login" + response.status == HttpServletResponse.SC_MOVED_TEMPORARILY + response.redirectedUrl == "http://localhost/login" when: "authenticate successfully" - response = new MockHttpServletResponse() - request = new MockHttpServletRequest(session: request.session) - request.servletPath = "/login" - request.setParameter(token.parameterName,token.token) - request.setParameter("username","user") - request.setParameter("password","password") - request.method = "POST" - springSecurityFilterChain.doFilter(request,response,chain) + response = new MockHttpServletResponse() + request = new MockHttpServletRequest(session: request.session) + request.servletPath = "/login" + request.setParameter(token.parameterName,token.token) + request.setParameter("username","user") + request.setParameter("password","password") + request.method = "POST" + springSecurityFilterChain.doFilter(request,response,chain) then: "sent to original URL since it was a GET" - response.status == HttpServletResponse.SC_MOVED_TEMPORARILY - response.redirectedUrl == "http://localhost/some-url" + response.status == HttpServletResponse.SC_MOVED_TEMPORARILY + response.redirectedUrl == "http://localhost/some-url" } def "SEC-2422: csrf expire CSRF token and session-management invalid-session-url"() { setup: - httpAutoConfig { - 'csrf'() - 'session-management'('invalid-session-url': '/error/sessionError') - } - createAppContext() - request.setParameter("_csrf","abc") - request.method = "POST" + httpAutoConfig { + 'csrf'() + 'session-management'('invalid-session-url': '/error/sessionError') + } + createAppContext() + request.setParameter("_csrf","abc") + request.method = "POST" when: "No existing expected CsrfToken (session times out) and a POST" - springSecurityFilterChain.doFilter(request,response,chain) + springSecurityFilterChain.doFilter(request,response,chain) then: "sent to the session timeout page page" - response.status == HttpServletResponse.SC_MOVED_TEMPORARILY - response.redirectedUrl == "/error/sessionError" + response.status == HttpServletResponse.SC_MOVED_TEMPORARILY + response.redirectedUrl == "/error/sessionError" when: "Existing expected CsrfToken and a POST (invalid token provided)" - response = new MockHttpServletResponse() - request = new MockHttpServletRequest(session: request.session, method:'POST') - springSecurityFilterChain.doFilter(request,response,chain) + response = new MockHttpServletResponse() + request = new MockHttpServletRequest(session: request.session, method:'POST') + springSecurityFilterChain.doFilter(request,response,chain) then: "Access Denied occurs" - response.status == HttpServletResponse.SC_FORBIDDEN + response.status == HttpServletResponse.SC_FORBIDDEN } def "csrf requireCsrfProtectionMatcher"() { setup: - httpAutoConfig { - 'csrf'('request-matcher-ref':'matcher') - } - mockBean(RequestMatcher,'matcher') - createAppContext() - request.method = 'POST' - RequestMatcher matcher = appContext.getBean("matcher",RequestMatcher) + httpAutoConfig { 'csrf'('request-matcher-ref':'matcher') } + mockBean(RequestMatcher,'matcher') + createAppContext() + request.method = 'POST' + RequestMatcher matcher = appContext.getBean("matcher",RequestMatcher) when: - when(matcher.matches(any(HttpServletRequest))).thenReturn(false) - springSecurityFilterChain.doFilter(request,response,chain) + when(matcher.matches(any(HttpServletRequest))).thenReturn(false) + springSecurityFilterChain.doFilter(request,response,chain) then: - response.status == HttpServletResponse.SC_OK + response.status == HttpServletResponse.SC_OK when: - when(matcher.matches(any(HttpServletRequest))).thenReturn(true) - springSecurityFilterChain.doFilter(request,response,chain) + when(matcher.matches(any(HttpServletRequest))).thenReturn(true) + springSecurityFilterChain.doFilter(request,response,chain) then: - response.status == HttpServletResponse.SC_FORBIDDEN + response.status == HttpServletResponse.SC_FORBIDDEN + } + + def "csrf csrfTokenRepository default delays save"() { + setup: + httpAutoConfig { + } + createAppContext() + request.method = "GET" + when: + springSecurityFilterChain.doFilter(request,response,chain) + then: + response.status == HttpServletResponse.SC_OK + request.getSession(false) == null } def "csrf csrfTokenRepository"() { setup: - httpAutoConfig { - 'csrf'('token-repository-ref':'repo') - } - mockBean(CsrfTokenRepository,'repo') - createAppContext() - CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository) - CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc") - when(repo.loadToken(any(HttpServletRequest))).thenReturn(token) - request.setParameter(token.parameterName,token.token) - request.method = "POST" + httpAutoConfig { 'csrf'('token-repository-ref':'repo') } + mockBean(CsrfTokenRepository,'repo') + createAppContext() + CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository) + CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc") + when(repo.loadToken(any(HttpServletRequest))).thenReturn(token) + request.setParameter(token.parameterName,token.token) + request.method = "POST" when: - springSecurityFilterChain.doFilter(request,response,chain) + springSecurityFilterChain.doFilter(request,response,chain) then: - response.status == HttpServletResponse.SC_OK + response.status == HttpServletResponse.SC_OK when: - request.setParameter(token.parameterName,token.token+"INVALID") - springSecurityFilterChain.doFilter(request,response,chain) + request.setParameter(token.parameterName,token.token+"INVALID") + springSecurityFilterChain.doFilter(request,response,chain) then: - response.status == HttpServletResponse.SC_FORBIDDEN + response.status == HttpServletResponse.SC_FORBIDDEN } def "csrf clears on login"() { setup: - httpAutoConfig { - 'csrf'('token-repository-ref':'repo') - } - mockBean(CsrfTokenRepository,'repo') - createAppContext() - CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository) - CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc") - when(repo.loadToken(any(HttpServletRequest))).thenReturn(token) - when(repo.generateToken(any(HttpServletRequest))).thenReturn(token) - request.setParameter(token.parameterName,token.token) - request.method = "POST" - request.setParameter("username","user") - request.setParameter("password","password") - request.servletPath = "/login" + httpAutoConfig { 'csrf'('token-repository-ref':'repo') } + mockBean(CsrfTokenRepository,'repo') + createAppContext() + CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository) + CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc") + when(repo.loadToken(any(HttpServletRequest))).thenReturn(token) + when(repo.generateToken(any(HttpServletRequest))).thenReturn(token) + request.setParameter(token.parameterName,token.token) + request.method = "POST" + request.setParameter("username","user") + request.setParameter("password","password") + request.servletPath = "/login" when: - springSecurityFilterChain.doFilter(request,response,chain) + springSecurityFilterChain.doFilter(request,response,chain) then: - verify(repo, atLeastOnce()).saveToken(eq(null),any(HttpServletRequest), any(HttpServletResponse)) + verify(repo, atLeastOnce()).saveToken(eq(null),any(HttpServletRequest), any(HttpServletResponse)) } def "csrf clears on logout"() { setup: - httpAutoConfig { - 'csrf'('token-repository-ref':'repo') - } - mockBean(CsrfTokenRepository,'repo') - createAppContext() - CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository) - CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc") - when(repo.loadToken(any(HttpServletRequest))).thenReturn(token) - request.setParameter(token.parameterName,token.token) - request.method = "POST" - request.servletPath = "/logout" + httpAutoConfig { 'csrf'('token-repository-ref':'repo') } + mockBean(CsrfTokenRepository,'repo') + createAppContext() + CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository) + CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc") + when(repo.loadToken(any(HttpServletRequest))).thenReturn(token) + request.setParameter(token.parameterName,token.token) + request.method = "POST" + request.servletPath = "/logout" when: - springSecurityFilterChain.doFilter(request,response,chain) + springSecurityFilterChain.doFilter(request,response,chain) then: - verify(repo).saveToken(eq(null),any(HttpServletRequest), any(HttpServletResponse)) + verify(repo).saveToken(eq(null),any(HttpServletRequest), any(HttpServletResponse)) } def "SEC-2495: csrf disables logout on GET"() { setup: - httpAutoConfig { - 'csrf'() - } - createAppContext() - login() - request.method = "GET" - request.requestURI = "/logout" + httpAutoConfig { 'csrf'() } + createAppContext() + login() + request.method = "GET" + request.requestURI = "/logout" when: - springSecurityFilterChain.doFilter(request,response,chain) + springSecurityFilterChain.doFilter(request,response,chain) then: - getAuthentication(request) != null + getAuthentication(request) != null } diff --git a/docs/manual/src/docs/asciidoc/index.adoc b/docs/manual/src/docs/asciidoc/index.adoc index 48f8d022ae..ad7b813772 100644 --- a/docs/manual/src/docs/asciidoc/index.adoc +++ b/docs/manual/src/docs/asciidoc/index.adoc @@ -378,6 +378,7 @@ You can find the highlights below: * <> * <> * <> +* <> provides simple AngularJS & CSRF integration * Added `ForwardAuthenticationFailureHandler` & `ForwardAuthenticationSuccessHandler` * SCrypt support with `SCryptPasswordEncoder` * Meta Annotation Support @@ -3252,6 +3253,7 @@ protected void configure(HttpSecurity http) throws Exception { } ---- + [[csrf-include-csrf-token]] ==== Include the CSRF Token @@ -3324,6 +3326,41 @@ name: $("meta[name='_csrf_header']").attr("content") The configured client can be shared with any component of the application that needs to make a request to the CSRF protected resource. One significant different between rest.js and jQuery is that only requests made with the configured client will contain the CSRF token, vs jQuery where __all__ requests will include the token. The ability to scope which requests receive the token helps guard against leaking the CSRF token to a third party. Please refer to the https://github.com/cujojs/rest/tree/master/docs[rest.js reference documentation] for more information on rest.js. +[[csrf-cookie]] +===== CookieCsrfTokenRepository + +There can be cases where users will want to persist the `CsrfToken` in a cookie. +By default the `CookieCsrfTokenRepository` will write to a cookie named `XSRF-TOKEN` and read it from a header named `X-XSRF-TOKEN` or the HTTP parameter `_csrf`. +These defaults come from https://docs.angularjs.org/api/ng/service/$http#cross-site-request-forgery-xsrf-protection[AngularJS] + +You can configure `CookieCsrfTokenRepository` in XML using the following: + +[source,xml] +---- + + + + + +---- + +You can configure `CookieCsrfTokenRepository` in Java Configuration using: + +[source,java] +---- +@EnableWebSecurity +public class WebSecurityConfig extends + WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + http + .csrf() + .csrfTokenRepository(new CookieCsrfTokenRepository()); + } +} +---- + [[csrf-caveats]] === CSRF Caveats @@ -3336,13 +3373,16 @@ One issue is that the expected CSRF token is stored in the HttpSession, so as so [NOTE] ==== -One might ask why the expected `CsrfToken` isn't stored in a cookie. This is because there are known exploits in which headers (i.e. specify the cookies) can be set by another domain. This is the same reason Ruby on Rails http://weblog.rubyonrails.org/2011/2/8/csrf-protection-bypass-in-ruby-on-rails/[no longer skips CSRF checks when the header X-Requested-With is present]. See http://lists.webappsec.org/pipermail/websecurity_lists.webappsec.org/2011-February/007533.html[this webappsec.org thread] for details on how to perform the exploit. Another disadvantage is that by removing the state (i.e. the timeout) you lose the ability to forcibly terminate the token if it is compromised. +One might ask why the expected `CsrfToken` isn't stored in a cookie by default. This is because there are known exploits in which headers (i.e. specify the cookies) can be set by another domain. This is the same reason Ruby on Rails http://weblog.rubyonrails.org/2011/2/8/csrf-protection-bypass-in-ruby-on-rails/[no longer skips CSRF checks when the header X-Requested-With is present]. See http://lists.webappsec.org/pipermail/websecurity_lists.webappsec.org/2011-February/007533.html[this webappsec.org thread] for details on how to perform the exploit. Another disadvantage is that by removing the state (i.e. the timeout) you lose the ability to forcibly terminate the token if it is compromised. ==== A simple way to mitigate an active user experiencing a timeout is to have some JavaScript that lets the user know their session is about to expire. The user can click a button to continue and refresh the session. Alternatively, specifying a custom `AccessDeniedHandler` allows you to process the `InvalidCsrfTokenException` any way you like. For an example of how to customize the `AccessDeniedHandler` refer to the provided links for both <> and https://github.com/spring-projects/spring-security/blob/3.2.0.RC1/config/src/test/groovy/org/springframework/security/config/annotation/web/configurers/NamespaceHttpAccessDeniedHandlerTests.groovy#L64[Java configuration]. +Finally, the application can be configured to use <> which will not expire. +As previously mentioned, this is not as secure as using a session, but in many cases can be good enough. + [[csrf-login]] ==== Logging In diff --git a/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java b/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java index 8c7cad5ff5..86bcef4779 100644 --- a/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java +++ b/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java @@ -53,6 +53,7 @@ import org.springframework.security.web.context.SecurityContextPersistenceFilter import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.CsrfTokenRepository; +import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.request.RequestPostProcessor; import org.springframework.util.Assert; @@ -107,8 +108,8 @@ public final class SecurityMockMvcRequestPostProcessors { * @throws IOException * @throws CertificateException */ - public static RequestPostProcessor x509(String resourceName) throws IOException, - CertificateException { + public static RequestPostProcessor x509(String resourceName) + throws IOException, CertificateException { ResourceLoader loader = new DefaultResourceLoader(); Resource resource = loader.getResource(resourceName); InputStream inputStream = resource.getInputStream(); @@ -142,24 +143,24 @@ public final class SecurityMockMvcRequestPostProcessors { * Establish a {@link SecurityContext} that has a * {@link UsernamePasswordAuthenticationToken} for the * {@link Authentication#getPrincipal()} and a {@link User} for the - * {@link UsernamePasswordAuthenticationToken#getPrincipal()}. All details - * are declarative and do not require that the user actually exists. + * {@link UsernamePasswordAuthenticationToken#getPrincipal()}. All details are + * declarative and do not require that the user actually exists. * *

- * The support works by associating the user to the HttpServletRequest. To - * associate the request to the SecurityContextHolder you need to ensure - * that the SecurityContextPersistenceFilter is associated with the - * MockMvc instance. A few ways to do this are: + * The support works by associating the user to the HttpServletRequest. To associate + * the request to the SecurityContextHolder you need to ensure that the + * SecurityContextPersistenceFilter is associated with the MockMvc instance. A few + * ways to do this are: *

* *
    *
  • Invoking apply {@link SecurityMockMvcConfigurers#springSecurity()}
  • *
  • Adding Spring Security's FilterChainProxy to MockMvc
  • - *
  • Manually adding {@link SecurityContextPersistenceFilter} to the MockMvc instance may make sense when using MockMvcBuilders standaloneSetup
  • + *
  • Manually adding {@link SecurityContextPersistenceFilter} to the MockMvc + * instance may make sense when using MockMvcBuilders standaloneSetup
  • *
* - * @param username - * the username to populate + * @param username the username to populate * @return the {@link UserRequestPostProcessor} for additional customization */ public static UserRequestPostProcessor user(String username) { @@ -174,16 +175,17 @@ public final class SecurityMockMvcRequestPostProcessors { * declarative and do not require that the user actually exists. * *

- * The support works by associating the user to the HttpServletRequest. To - * associate the request to the SecurityContextHolder you need to ensure - * that the SecurityContextPersistenceFilter is associated with the - * MockMvc instance. A few ways to do this are: + * The support works by associating the user to the HttpServletRequest. To associate + * the request to the SecurityContextHolder you need to ensure that the + * SecurityContextPersistenceFilter is associated with the MockMvc instance. A few + * ways to do this are: *

* *
    *
  • Invoking apply {@link SecurityMockMvcConfigurers#springSecurity()}
  • *
  • Adding Spring Security's FilterChainProxy to MockMvc
  • - *
  • Manually adding {@link SecurityContextPersistenceFilter} to the MockMvc instance may make sense when using MockMvcBuilders standaloneSetup
  • + *
  • Manually adding {@link SecurityContextPersistenceFilter} to the MockMvc + * instance may make sense when using MockMvcBuilders standaloneSetup
  • *
* * @param user the UserDetails to populate @@ -199,16 +201,17 @@ public final class SecurityMockMvcRequestPostProcessors { * details are declarative and do not require that the user actually exists. * *

- * The support works by associating the user to the HttpServletRequest. To - * associate the request to the SecurityContextHolder you need to ensure - * that the SecurityContextPersistenceFilter is associated with the - * MockMvc instance. A few ways to do this are: + * The support works by associating the user to the HttpServletRequest. To associate + * the request to the SecurityContextHolder you need to ensure that the + * SecurityContextPersistenceFilter is associated with the MockMvc instance. A few + * ways to do this are: *

* *
    *
  • Invoking apply {@link SecurityMockMvcConfigurers#springSecurity()}
  • *
  • Adding Spring Security's FilterChainProxy to MockMvc
  • - *
  • Manually adding {@link SecurityContextPersistenceFilter} to the MockMvc instance may make sense when using MockMvcBuilders standaloneSetup
  • + *
  • Manually adding {@link SecurityContextPersistenceFilter} to the MockMvc + * instance may make sense when using MockMvcBuilders standaloneSetup
  • *
* * @param authentication the Authentication to populate @@ -220,9 +223,9 @@ public final class SecurityMockMvcRequestPostProcessors { /** * Establish a {@link SecurityContext} that uses an - * {@link AnonymousAuthenticationToken}. This is useful when a user wants to - * run a majority of tests as a specific user and wishes to override a few - * methods to be anonymous. For example: + * {@link AnonymousAuthenticationToken}. This is useful when a user wants to run a + * majority of tests as a specific user and wishes to override a few methods to be + * anonymous. For example: * *
 	 * 
@@ -241,8 +244,7 @@ public final class SecurityMockMvcRequestPostProcessors {
 	 *     }
 	 *     // ... lots of tests ran with a default user ...
 	 * }
-	 * 
-	 * 
+ * * * @return the {@link RequestPostProcessor} to use */ @@ -254,11 +256,10 @@ public final class SecurityMockMvcRequestPostProcessors { * Establish the specified {@link SecurityContext} to be used. * *

- * This works by associating the user to the {@link HttpServletRequest}. To - * associate the request to the {@link SecurityContextHolder} you need to - * ensure that the {@link SecurityContextPersistenceFilter} (i.e. Spring - * Security's FilterChainProxy will typically do this) is associated with - * the {@link MockMvc} instance. + * This works by associating the user to the {@link HttpServletRequest}. To associate + * the request to the {@link SecurityContextHolder} you need to ensure that the + * {@link SecurityContextPersistenceFilter} (i.e. Spring Security's FilterChainProxy + * will typically do this) is associated with the {@link MockMvc} instance. *

*/ public static RequestPostProcessor securityContext(SecurityContext securityContext) { @@ -289,8 +290,10 @@ public final class SecurityMockMvcRequestPostProcessors { this.certificates = certificates; } + @Override public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { - request.setAttribute("javax.servlet.request.X509Certificate", certificates); + request.setAttribute("javax.servlet.request.X509Certificate", + this.certificates); return request; } } @@ -313,18 +316,20 @@ public final class SecurityMockMvcRequestPostProcessors { * @see org.springframework.test.web.servlet.request.RequestPostProcessor * #postProcessRequest (org.springframework.mock.web.MockHttpServletRequest) */ + @Override public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { CsrfTokenRepository repository = WebTestUtils.getCsrfTokenRepository(request); - if(!(repository instanceof TestCsrfTokenRepository)) { - repository = new TestCsrfTokenRepository(repository); + if (!(repository instanceof TestCsrfTokenRepository)) { + repository = new TestCsrfTokenRepository( + new HttpSessionCsrfTokenRepository()); WebTestUtils.setCsrfTokenRepository(request, repository); } CsrfToken token = repository.generateToken(request); repository.saveToken(token, request, new MockHttpServletResponse()); - String tokenValue = useInvalidToken ? "invalid" + token.getToken() : token - .getToken(); - if (asHeader) { + String tokenValue = this.useInvalidToken ? "invalid" + token.getToken() + : token.getToken(); + if (this.asHeader) { request.addHeader(token.getHeaderName(), tokenValue); } else { @@ -357,16 +362,13 @@ public final class SecurityMockMvcRequestPostProcessors { private CsrfRequestPostProcessor() { } - - /** - * Used to wrap the CsrfTokenRepository to provide support for testing - * when the request is wrapped (i.e. Spring Session is in use). + * Used to wrap the CsrfTokenRepository to provide support for testing when the + * request is wrapped (i.e. Spring Session is in use). */ - static class TestCsrfTokenRepository implements - CsrfTokenRepository { - final static String ATTR_NAME = TestCsrfTokenRepository.class - .getName().concat(".TOKEN"); + static class TestCsrfTokenRepository implements CsrfTokenRepository { + final static String ATTR_NAME = TestCsrfTokenRepository.class.getName() + .concat(".TOKEN"); private final CsrfTokenRepository delegate; @@ -374,14 +376,18 @@ public final class SecurityMockMvcRequestPostProcessors { this.delegate = delegate; } + @Override public CsrfToken generateToken(HttpServletRequest request) { - return delegate.generateToken(request); + return this.delegate.generateToken(request); } - public void saveToken(CsrfToken token, HttpServletRequest request, HttpServletResponse response) { + @Override + public void saveToken(CsrfToken token, HttpServletRequest request, + HttpServletResponse response) { request.setAttribute(ATTR_NAME, token); } + @Override public CsrfToken loadToken(HttpServletRequest request) { return (CsrfToken) request.getAttribute(ATTR_NAME); } @@ -447,14 +453,16 @@ public final class SecurityMockMvcRequestPostProcessors { private String createAuthorizationHeader(MockHttpServletRequest request) { String uri = request.getRequestURI(); - String responseDigest = generateDigest(username, realm, password, - request.getMethod(), uri, qop, nonce, nc, cnonce); - return "Digest username=\"" + username + "\", realm=\"" + realm - + "\", nonce=\"" + nonce + "\", uri=\"" + uri + "\", response=\"" - + responseDigest + "\", qop=" + qop + ", nc=" + nc + ", cnonce=\"" - + cnonce + "\""; + String responseDigest = generateDigest(this.username, this.realm, + this.password, request.getMethod(), uri, this.qop, this.nonce, + this.nc, this.cnonce); + return "Digest username=\"" + this.username + "\", realm=\"" + this.realm + + "\", nonce=\"" + this.nonce + "\", uri=\"" + uri + "\", response=\"" + + responseDigest + "\", qop=" + this.qop + ", nc=" + this.nc + + ", cnonce=\"" + this.cnonce + "\""; } + @Override public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { request.addHeader("Authorization", createAuthorizationHeader(request)); @@ -574,8 +582,7 @@ public final class SecurityMockMvcRequestPostProcessors { * Used to wrap the SecurityContextRepository to provide support for testing in * stateless mode */ - static class TestSecurityContextRepository implements - SecurityContextRepository { + static class TestSecurityContextRepository implements SecurityContextRepository { private final static String ATTR_NAME = TestSecurityContextRepository.class .getName().concat(".REPO"); @@ -585,6 +592,7 @@ public final class SecurityMockMvcRequestPostProcessors { this.delegate = delegate; } + @Override public SecurityContext loadContext( HttpRequestResponseHolder requestResponseHolder) { SecurityContext result = getContext(requestResponseHolder.getRequest()); @@ -592,19 +600,22 @@ public final class SecurityMockMvcRequestPostProcessors { // holder are updated // remember the SecurityContextRepository is used in many different // locations - SecurityContext delegateResult = delegate + SecurityContext delegateResult = this.delegate .loadContext(requestResponseHolder); return result == null ? delegateResult : result; } + @Override public void saveContext(SecurityContext context, HttpServletRequest request, HttpServletResponse response) { request.setAttribute(ATTR_NAME, context); - delegate.saveContext(context, request, response); + this.delegate.saveContext(context, request, response); } + @Override public boolean containsContext(HttpServletRequest request) { - return getContext(request) != null || delegate.containsContext(request); + return getContext(request) != null + || this.delegate.containsContext(request); } private static SecurityContext getContext(HttpServletRequest request) { @@ -625,15 +636,17 @@ public final class SecurityMockMvcRequestPostProcessors { SecurityContextRequestPostProcessorSupport implements RequestPostProcessor { private SecurityContext EMPTY = SecurityContextHolder.createEmptyContext(); + @Override public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { // TestSecurityContextHolder is only a default value - SecurityContext existingContext = TestSecurityContextRepository.getContext(request); - if(existingContext != null) { + SecurityContext existingContext = TestSecurityContextRepository + .getContext(request); + if (existingContext != null) { return request; } SecurityContext context = TestSecurityContextHolder.getContext(); - if(!EMPTY.equals(context)) { + if (!this.EMPTY.equals(context)) { save(context, request); } @@ -657,6 +670,7 @@ public final class SecurityMockMvcRequestPostProcessors { this.securityContext = securityContext; } + @Override public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { save(this.securityContext, request); return request; @@ -679,10 +693,11 @@ public final class SecurityMockMvcRequestPostProcessors { this.authentication = authentication; } + @Override public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { SecurityContext context = SecurityContextHolder.createEmptyContext(); - context.setAuthentication(authentication); - save(authentication, request); + context.setAuthentication(this.authentication); + save(this.authentication, request); return request; } } @@ -695,19 +710,20 @@ public final class SecurityMockMvcRequestPostProcessors { * @author Rob Winch * @since 4.0 */ - private final static class UserDetailsRequestPostProcessor implements - RequestPostProcessor { + private final static class UserDetailsRequestPostProcessor + implements RequestPostProcessor { private final RequestPostProcessor delegate; public UserDetailsRequestPostProcessor(UserDetails user) { Authentication token = new UsernamePasswordAuthenticationToken(user, user.getPassword(), user.getAuthorities()); - delegate = new AuthenticationRequestPostProcessor(token); + this.delegate = new AuthenticationRequestPostProcessor(token); } + @Override public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { - return delegate.postProcessRequest(request); + return this.delegate.postProcessRequest(request); } } @@ -752,8 +768,8 @@ public final class SecurityMockMvcRequestPostProcessors { * {@link #authorities(GrantedAuthority...)}, but just not as flexible. * * @param roles The roles to populate. Note that if the role does not start with - * {@link #ROLE_PREFIX} it will automatically be prepended. This means by - * default {@code roles("ROLE_USER")} and {@code roles("USER")} are equivalent. + * {@link #ROLE_PREFIX} it will automatically be prepended. This means by default + * {@code roles("ROLE_USER")} and {@code roles("USER")} are equivalent. * @see #authorities(GrantedAuthority...) * @see #ROLE_PREFIX * @return the UserRequestPostProcessor for further customizations @@ -764,8 +780,7 @@ public final class SecurityMockMvcRequestPostProcessors { for (String role : roles) { if (role.startsWith(ROLE_PREFIX)) { throw new IllegalArgumentException( - "Role should not start with " - + ROLE_PREFIX + "Role should not start with " + ROLE_PREFIX + " since this method automatically prefixes with this value. Got " + role); } @@ -812,6 +827,7 @@ public final class SecurityMockMvcRequestPostProcessors { return this; } + @Override public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { UserDetailsRequestPostProcessor delegate = new UserDetailsRequestPostProcessor( createUser()); @@ -823,19 +839,27 @@ public final class SecurityMockMvcRequestPostProcessors { * @return the {@link User} for the principal */ private User createUser() { - return new User(username, password, enabled, accountNonExpired, - credentialsNonExpired, accountNonLocked, authorities); + return new User(this.username, this.password, this.enabled, + this.accountNonExpired, this.credentialsNonExpired, + this.accountNonLocked, this.authorities); } } - private static class AnonymousRequestPostProcessor extends SecurityContextRequestPostProcessorSupport implements RequestPostProcessor { - private AuthenticationRequestPostProcessor delegate = new AuthenticationRequestPostProcessor(new AnonymousAuthenticationToken("key", "anonymous", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"))); + private static class AnonymousRequestPostProcessor extends + SecurityContextRequestPostProcessorSupport implements RequestPostProcessor { + private AuthenticationRequestPostProcessor delegate = new AuthenticationRequestPostProcessor( + new AnonymousAuthenticationToken("key", "anonymous", + AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"))); - /* (non-Javadoc) - * @see org.springframework.test.web.servlet.request.RequestPostProcessor#postProcessRequest(org.springframework.mock.web.MockHttpServletRequest) + /* + * (non-Javadoc) + * + * @see org.springframework.test.web.servlet.request.RequestPostProcessor# + * postProcessRequest(org.springframework.mock.web.MockHttpServletRequest) */ + @Override public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { - return delegate.postProcessRequest(request); + return this.delegate.postProcessRequest(request); } } @@ -853,8 +877,9 @@ public final class SecurityMockMvcRequestPostProcessors { this.headerValue = "Basic " + new String(Base64.encode(toEncode)); } + @Override public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { - request.addHeader("Authorization", headerValue); + request.addHeader("Authorization", this.headerValue); return request; } } diff --git a/web/src/main/java/org/springframework/security/web/csrf/CookieCsrfTokenRepository.java b/web/src/main/java/org/springframework/security/web/csrf/CookieCsrfTokenRepository.java new file mode 100644 index 0000000000..1a85dd9761 --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/csrf/CookieCsrfTokenRepository.java @@ -0,0 +1,126 @@ +/* + * Copyright 2012-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web.csrf; + +import java.util.UUID; + +import javax.servlet.http.Cookie; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; +import org.springframework.web.util.WebUtils; + +/** + * A {@link CsrfTokenRepository} that persist the CSRF token in a cookie named + * "XSRF-TOKEN" and reads from the header "X-XSRF-TOKEN" following the conventions of + * AngularJS. + * + * @author Rob Winch + * @since 4.1 + */ +public final class CookieCsrfTokenRepository implements CsrfTokenRepository { + static final String DEFAULT_CSRF_COOKIE_NAME = "XSRF-TOKEN"; + + static final String DEFAULT_CSRF_PARAMETER_NAME = "_csrf"; + + static final String DEFAULT_CSRF_HEADER_NAME = "X-XSRF-TOKEN"; + + private String parameterName = DEFAULT_CSRF_PARAMETER_NAME; + + private String headerName = DEFAULT_CSRF_HEADER_NAME; + + private String cookieName = DEFAULT_CSRF_COOKIE_NAME; + + @Override + public CsrfToken generateToken(HttpServletRequest request) { + return new DefaultCsrfToken(this.headerName, this.parameterName, + createNewToken()); + } + + @Override + public void saveToken(CsrfToken token, HttpServletRequest request, + HttpServletResponse response) { + String tokenValue = token == null ? "" : token.getToken(); + Cookie cookie = new Cookie(this.cookieName, tokenValue); + cookie.setSecure(request.isSecure()); + cookie.setPath(getCookiePath(request)); + if (token == null) { + cookie.setMaxAge(0); + } + else { + cookie.setMaxAge(-1); + } + response.addCookie(cookie); + } + + @Override + public CsrfToken loadToken(HttpServletRequest request) { + Cookie cookie = WebUtils.getCookie(request, this.cookieName); + if (cookie == null) { + return null; + } + String token = cookie.getValue(); + if (!StringUtils.hasLength(token)) { + return null; + } + return new DefaultCsrfToken(this.headerName, this.parameterName, token); + } + + /** + * Sets the name of the HTTP request parameter that should be used to provide a token. + * + * @param parameterName the name of the HTTP request parameter that should be used to + * provide a token + */ + public void setParameterName(String parameterName) { + Assert.notNull(parameterName, "parameterName is not null"); + this.parameterName = parameterName; + } + + /** + * Sets the name of the HTTP header that should be used to provide the token + * + * @param headerName the name of the HTTP header that should be used to provide the + * token + */ + public void setHeaderName(String headerName) { + Assert.notNull(headerName, "headerName is not null"); + this.headerName = headerName; + } + + /** + * Sets the name of the cookie that the expected CSRF token is saved to and read from + * + * @param cookieName the name of the cookie that the expected CSRF token is saved to + * and read from + */ + public void setCookieName(String cookieName) { + Assert.notNull(cookieName, "cookieName is not null"); + this.cookieName = cookieName; + } + + private String getCookiePath(HttpServletRequest request) { + String contextPath = request.getContextPath(); + return contextPath.length() > 0 ? contextPath : "/"; + } + + private String createNewToken() { + return UUID.randomUUID().toString(); + } +} \ No newline at end of file diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java index 7a9fc862db..fdf1e2c8c1 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java @@ -52,6 +52,7 @@ public final class CsrfAuthenticationStrategy implements SessionAuthenticationSt * #onAuthentication(org.springframework.security.core.Authentication, * javax.servlet.http.HttpServletRequest, javax.servlet.http.HttpServletResponse) */ + @Override public void onAuthentication(Authentication authentication, HttpServletRequest request, HttpServletResponse response) throws SessionAuthenticationException { @@ -60,96 +61,10 @@ public final class CsrfAuthenticationStrategy implements SessionAuthenticationSt this.csrfTokenRepository.saveToken(null, request, response); CsrfToken newToken = this.csrfTokenRepository.generateToken(request); - CsrfToken tokenForRequest = new SaveOnAccessCsrfToken( - this.csrfTokenRepository, request, response, newToken); + this.csrfTokenRepository.saveToken(newToken, request, response); - request.setAttribute(CsrfToken.class.getName(), tokenForRequest); - request.setAttribute(newToken.getParameterName(), tokenForRequest); + request.setAttribute(CsrfToken.class.getName(), newToken); + request.setAttribute(newToken.getParameterName(), newToken); } } - - private static final class SaveOnAccessCsrfToken implements CsrfToken { - private transient CsrfTokenRepository tokenRepository; - private transient HttpServletRequest request; - private transient HttpServletResponse response; - - private final CsrfToken delegate; - - public SaveOnAccessCsrfToken(CsrfTokenRepository tokenRepository, - HttpServletRequest request, HttpServletResponse response, - CsrfToken delegate) { - super(); - this.tokenRepository = tokenRepository; - this.request = request; - this.response = response; - this.delegate = delegate; - } - - public String getHeaderName() { - return this.delegate.getHeaderName(); - } - - public String getParameterName() { - return this.delegate.getParameterName(); - } - - public String getToken() { - saveTokenIfNecessary(); - return this.delegate.getToken(); - } - - @Override - public String toString() { - return "SaveOnAccessCsrfToken [delegate=" + this.delegate + "]"; - } - - @Override - public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result - + ((this.delegate == null) ? 0 : this.delegate.hashCode()); - return result; - } - - @Override - public boolean equals(Object obj) { - if (this == obj) { - return true; - } - if (obj == null) { - return false; - } - if (getClass() != obj.getClass()) { - return false; - } - SaveOnAccessCsrfToken other = (SaveOnAccessCsrfToken) obj; - if (this.delegate == null) { - if (other.delegate != null) { - return false; - } - } - else if (!this.delegate.equals(other.delegate)) { - return false; - } - return true; - } - - private void saveTokenIfNecessary() { - if (this.tokenRepository == null) { - return; - } - - synchronized (this) { - if (this.tokenRepository != null) { - this.tokenRepository.saveToken(this.delegate, this.request, - this.response); - this.tokenRepository = null; - this.request = null; - this.response = null; - } - } - } - - } } diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java index 03c155178e..0a62e303bc 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java @@ -27,27 +27,29 @@ import javax.servlet.http.HttpSession; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + import org.springframework.security.web.access.AccessDeniedHandler; import org.springframework.security.web.access.AccessDeniedHandlerImpl; -import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.security.web.util.UrlUtils; +import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; import org.springframework.web.filter.OncePerRequestFilter; /** *

- * Applies CSRF protection using a synchronizer token pattern. Developers are required to - * ensure that {@link CsrfFilter} is invoked for any request that allows state to change. - * Typically this just means that they should ensure their web application follows proper - * REST semantics (i.e. do not change state with the HTTP methods GET, HEAD, TRACE, - * OPTIONS). + * Applies + * CSRF + * protection using a synchronizer token pattern. Developers are required to ensure that + * {@link CsrfFilter} is invoked for any request that allows state to change. Typically + * this just means that they should ensure their web application follows proper REST + * semantics (i.e. do not change state with the HTTP methods GET, HEAD, TRACE, OPTIONS). *

* *

* Typically the {@link CsrfTokenRepository} implementation chooses to store the - * {@link CsrfToken} in {@link HttpSession} with {@link HttpSessionCsrfTokenRepository}. - * This is preferred to storing the token in a cookie which can be modified by a client application. + * {@link CsrfToken} in {@link HttpSession} with {@link HttpSessionCsrfTokenRepository} + * wrapped by a {@link LazyCsrfTokenRepository}. This is preferred to storing the token in + * a cookie which can be modified by a client application. *

* * @author Rob Winch @@ -82,18 +84,19 @@ public final class CsrfFilter extends OncePerRequestFilter { @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) - throws ServletException, IOException { - CsrfToken csrfToken = tokenRepository.loadToken(request); + throws ServletException, IOException { + request.setAttribute(HttpServletResponse.class.getName(), response); + + CsrfToken csrfToken = this.tokenRepository.loadToken(request); final boolean missingToken = csrfToken == null; if (missingToken) { - CsrfToken generatedToken = tokenRepository.generateToken(request); - csrfToken = new SaveOnAccessCsrfToken(tokenRepository, request, response, - generatedToken); + csrfToken = this.tokenRepository.generateToken(request); + this.tokenRepository.saveToken(csrfToken, request, response); } request.setAttribute(CsrfToken.class.getName(), csrfToken); request.setAttribute(csrfToken.getParameterName(), csrfToken); - if (!requireCsrfProtectionMatcher.matches(request)) { + if (!this.requireCsrfProtectionMatcher.matches(request)) { filterChain.doFilter(request, response); return; } @@ -103,16 +106,16 @@ public final class CsrfFilter extends OncePerRequestFilter { actualToken = request.getParameter(csrfToken.getParameterName()); } if (!csrfToken.getToken().equals(actualToken)) { - if (logger.isDebugEnabled()) { - logger.debug("Invalid CSRF token found for " + if (this.logger.isDebugEnabled()) { + this.logger.debug("Invalid CSRF token found for " + UrlUtils.buildFullRequestUrl(request)); } if (missingToken) { - accessDeniedHandler.handle(request, response, + this.accessDeniedHandler.handle(request, response, new MissingCsrfTokenException(actualToken)); } else { - accessDeniedHandler.handle(request, response, + this.accessDeniedHandler.handle(request, response, new InvalidCsrfTokenException(csrfToken, actualToken)); } return; @@ -156,87 +159,9 @@ public final class CsrfFilter extends OncePerRequestFilter { this.accessDeniedHandler = accessDeniedHandler; } - @SuppressWarnings("serial") - private static final class SaveOnAccessCsrfToken implements CsrfToken { - private transient CsrfTokenRepository tokenRepository; - private transient HttpServletRequest request; - private transient HttpServletResponse response; - - private final CsrfToken delegate; - - public SaveOnAccessCsrfToken(CsrfTokenRepository tokenRepository, - HttpServletRequest request, HttpServletResponse response, - CsrfToken delegate) { - super(); - this.tokenRepository = tokenRepository; - this.request = request; - this.response = response; - this.delegate = delegate; - } - - public String getHeaderName() { - return delegate.getHeaderName(); - } - - public String getParameterName() { - return delegate.getParameterName(); - } - - public String getToken() { - saveTokenIfNecessary(); - return delegate.getToken(); - } - - @Override - public String toString() { - return "SaveOnAccessCsrfToken [delegate=" + delegate + "]"; - } - - @Override - public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + ((delegate == null) ? 0 : delegate.hashCode()); - return result; - } - - @Override - public boolean equals(Object obj) { - if (this == obj) - return true; - if (obj == null) - return false; - if (getClass() != obj.getClass()) - return false; - SaveOnAccessCsrfToken other = (SaveOnAccessCsrfToken) obj; - if (delegate == null) { - if (other.delegate != null) - return false; - } - else if (!delegate.equals(other.delegate)) - return false; - return true; - } - - private void saveTokenIfNecessary() { - if (this.tokenRepository == null) { - return; - } - - synchronized (this) { - if (tokenRepository != null) { - this.tokenRepository.saveToken(delegate, request, response); - this.tokenRepository = null; - this.request = null; - this.response = null; - } - } - } - - } - private static final class DefaultRequiresCsrfMatcher implements RequestMatcher { - private final HashSet allowedMethods = new HashSet(Arrays.asList("GET", "HEAD", "TRACE", "OPTIONS")); + private final HashSet allowedMethods = new HashSet( + Arrays.asList("GET", "HEAD", "TRACE", "OPTIONS")); /* * (non-Javadoc) @@ -245,8 +170,9 @@ public final class CsrfFilter extends OncePerRequestFilter { * org.springframework.security.web.util.matcher.RequestMatcher#matches(javax. * servlet.http.HttpServletRequest) */ + @Override public boolean matches(HttpServletRequest request) { - return !allowedMethods.contains(request.getMethod()); + return !this.allowedMethods.contains(request.getMethod()); } } } diff --git a/web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java b/web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java new file mode 100644 index 0000000000..5d726b7f38 --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java @@ -0,0 +1,186 @@ +/* + * Copyright 2012-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web.csrf; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.springframework.util.Assert; + +/** + * A {@link CsrfTokenRepository} that delays saving new {@link CsrfToken} until the + * attributes of the {@link CsrfToken} that were generated are accessed. + * + * @author Rob Winch + * @since 4.1 + */ +public final class LazyCsrfTokenRepository implements CsrfTokenRepository { + /** + * The {@link HttpServletRequest} attribute name that the {@link HttpServletResponse} + * must be on. + */ + private static final String HTTP_RESPONSE_ATTR = HttpServletResponse.class.getName(); + + private final CsrfTokenRepository delegate; + + /** + * Creates a new instance + * @param delegate the {@link CsrfTokenRepository} to use. Cannot be null + * @throws IllegalArgumentException if delegate is null. + */ + public LazyCsrfTokenRepository(CsrfTokenRepository delegate) { + Assert.notNull(delegate, "delegate cannot be null"); + this.delegate = delegate; + } + + /** + * Generates a new token + * @param request the {@link HttpServletRequest} to use. The + * {@link HttpServletRequest} must have the {@link HttpServletResponse} as an + * attribute with the name of HttpServletResponse.class.getName() + */ + @Override + public CsrfToken generateToken(HttpServletRequest request) { + return wrap(request, this.delegate.generateToken(request)); + } + + /** + * Does nothing if the {@link CsrfToken} is not null. Saving is done only when the + * {@link CsrfToken#getToken()} iis accessed from + * {@link #generateToken(HttpServletRequest)}. If it is null, then the save is + * performed immediately. + */ + @Override + public void saveToken(CsrfToken token, HttpServletRequest request, + HttpServletResponse response) { + if (token == null) { + this.delegate.saveToken(token, request, response); + } + } + + /** + * Delegates to the injected {@link CsrfTokenRepository} + */ + @Override + public CsrfToken loadToken(HttpServletRequest request) { + return this.delegate.loadToken(request); + } + + private CsrfToken wrap(HttpServletRequest request, CsrfToken token) { + HttpServletResponse response = getResponse(request); + return new SaveOnAccessCsrfToken(this.delegate, request, response, token); + } + + private HttpServletResponse getResponse(HttpServletRequest request) { + HttpServletResponse response = (HttpServletResponse) request + .getAttribute(HTTP_RESPONSE_ATTR); + if (response == null) { + throw new IllegalArgumentException( + "The HttpServletRequest attribute must contain an HttpServletResponse for the attribute " + + HTTP_RESPONSE_ATTR); + } + return response; + } + + private static final class SaveOnAccessCsrfToken implements CsrfToken { + private transient CsrfTokenRepository tokenRepository; + private transient HttpServletRequest request; + private transient HttpServletResponse response; + + private final CsrfToken delegate; + + SaveOnAccessCsrfToken(CsrfTokenRepository tokenRepository, + HttpServletRequest request, HttpServletResponse response, + CsrfToken delegate) { + super(); + this.tokenRepository = tokenRepository; + this.request = request; + this.response = response; + this.delegate = delegate; + } + + @Override + public String getHeaderName() { + return this.delegate.getHeaderName(); + } + + @Override + public String getParameterName() { + return this.delegate.getParameterName(); + } + + @Override + public String getToken() { + saveTokenIfNecessary(); + return this.delegate.getToken(); + } + + @Override + public String toString() { + return "SaveOnAccessCsrfToken [delegate=" + this.delegate + "]"; + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + + ((this.delegate == null) ? 0 : this.delegate.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + SaveOnAccessCsrfToken other = (SaveOnAccessCsrfToken) obj; + if (this.delegate == null) { + if (other.delegate != null) { + return false; + } + } + else if (!this.delegate.equals(other.delegate)) { + return false; + } + return true; + } + + private void saveTokenIfNecessary() { + if (this.tokenRepository == null) { + return; + } + + synchronized (this) { + if (this.tokenRepository != null) { + this.tokenRepository.saveToken(this.delegate, this.request, + this.response); + this.tokenRepository = null; + this.request = null; + this.response = null; + } + } + } + + } +} diff --git a/web/src/test/java/org/springframework/security/web/csrf/CookieCsrfTokenRepositoryTests.java b/web/src/test/java/org/springframework/security/web/csrf/CookieCsrfTokenRepositoryTests.java new file mode 100644 index 0000000000..c0d3ec8e98 --- /dev/null +++ b/web/src/test/java/org/springframework/security/web/csrf/CookieCsrfTokenRepositoryTests.java @@ -0,0 +1,189 @@ +/* + * Copyright 2012-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web.csrf; + +import javax.servlet.http.Cookie; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Rob Winch + * @since 4.1 + */ +public class CookieCsrfTokenRepositoryTests { + CookieCsrfTokenRepository repository; + MockHttpServletResponse response; + MockHttpServletRequest request; + + @Before + public void setup() { + this.repository = new CookieCsrfTokenRepository(); + this.request = new MockHttpServletRequest(); + this.response = new MockHttpServletResponse(); + this.request.setContextPath("/context"); + } + + @Test + public void generateToken() { + CsrfToken generateToken = this.repository.generateToken(this.request); + + assertThat(generateToken).isNotNull(); + assertThat(generateToken.getHeaderName()) + .isEqualTo(CookieCsrfTokenRepository.DEFAULT_CSRF_HEADER_NAME); + assertThat(generateToken.getParameterName()) + .isEqualTo(CookieCsrfTokenRepository.DEFAULT_CSRF_PARAMETER_NAME); + assertThat(generateToken.getToken()).isNotEmpty(); + } + + @Test + public void generateTokenCustom() { + String headerName = "headerName"; + String parameterName = "paramName"; + this.repository.setHeaderName(headerName); + this.repository.setParameterName(parameterName); + + CsrfToken generateToken = this.repository.generateToken(this.request); + + assertThat(generateToken).isNotNull(); + assertThat(generateToken.getHeaderName()).isEqualTo(headerName); + assertThat(generateToken.getParameterName()).isEqualTo(parameterName); + assertThat(generateToken.getToken()).isNotEmpty(); + } + + @Test + public void saveToken() { + CsrfToken token = this.repository.generateToken(this.request); + this.repository.saveToken(token, this.request, this.response); + + Cookie tokenCookie = this.response + .getCookie(CookieCsrfTokenRepository.DEFAULT_CSRF_COOKIE_NAME); + + assertThat(tokenCookie.getMaxAge()).isEqualTo(-1); + assertThat(tokenCookie.getName()) + .isEqualTo(CookieCsrfTokenRepository.DEFAULT_CSRF_COOKIE_NAME); + assertThat(tokenCookie.getPath()).isEqualTo(this.request.getContextPath()); + assertThat(tokenCookie.getSecure()).isEqualTo(this.request.isSecure()); + assertThat(tokenCookie.getValue()).isEqualTo(token.getToken()); + } + + @Test + public void saveTokenSecure() { + this.request.setSecure(true); + CsrfToken token = this.repository.generateToken(this.request); + this.repository.saveToken(token, this.request, this.response); + + Cookie tokenCookie = this.response + .getCookie(CookieCsrfTokenRepository.DEFAULT_CSRF_COOKIE_NAME); + + assertThat(tokenCookie.getSecure()).isTrue(); + } + + @Test + public void saveTokenNull() { + this.request.setSecure(true); + this.repository.saveToken(null, this.request, this.response); + + Cookie tokenCookie = this.response + .getCookie(CookieCsrfTokenRepository.DEFAULT_CSRF_COOKIE_NAME); + + assertThat(tokenCookie.getMaxAge()).isEqualTo(0); + assertThat(tokenCookie.getName()) + .isEqualTo(CookieCsrfTokenRepository.DEFAULT_CSRF_COOKIE_NAME); + assertThat(tokenCookie.getPath()).isEqualTo(this.request.getContextPath()); + assertThat(tokenCookie.getSecure()).isEqualTo(this.request.isSecure()); + assertThat(tokenCookie.getValue()).isEmpty(); + } + + @Test + public void loadTokenNoCookiesNull() { + assertThat(this.repository.loadToken(this.request)).isNull(); + } + + @Test + public void loadTokenCookieIncorrectNameNull() { + this.request.setCookies(new Cookie("other", "name")); + + assertThat(this.repository.loadToken(this.request)).isNull(); + } + + @Test + public void loadTokenCookieValueEmptyString() { + this.request.setCookies( + new Cookie(CookieCsrfTokenRepository.DEFAULT_CSRF_COOKIE_NAME, "")); + + assertThat(this.repository.loadToken(this.request)).isNull(); + } + + @Test + public void loadToken() { + CsrfToken generateToken = this.repository.generateToken(this.request); + + this.request + .setCookies(new Cookie(CookieCsrfTokenRepository.DEFAULT_CSRF_COOKIE_NAME, + generateToken.getToken())); + + CsrfToken loadToken = this.repository.loadToken(this.request); + + assertThat(loadToken).isNotNull(); + assertThat(loadToken.getHeaderName()).isEqualTo(generateToken.getHeaderName()); + assertThat(loadToken.getParameterName()) + .isEqualTo(generateToken.getParameterName()); + assertThat(loadToken.getToken()).isNotEmpty(); + } + + @Test + public void loadTokenCustom() { + String cookieName = "cookieName"; + String value = "value"; + String headerName = "headerName"; + String parameterName = "paramName"; + this.repository.setHeaderName(headerName); + this.repository.setParameterName(parameterName); + this.repository.setCookieName(cookieName); + + this.request.setCookies(new Cookie(cookieName, value)); + + CsrfToken loadToken = this.repository.loadToken(this.request); + + assertThat(loadToken).isNotNull(); + assertThat(loadToken.getHeaderName()).isEqualTo(headerName); + assertThat(loadToken.getParameterName()).isEqualTo(parameterName); + assertThat(loadToken.getToken()).isEqualTo(value); + } + + @Test(expected = IllegalArgumentException.class) + public void setCookieNameNullIllegalArgumentException() { + this.repository.setCookieName(null); + } + + @Test(expected = IllegalArgumentException.class) + public void setParameterNameNullIllegalArgumentException() { + this.repository.setParameterName(null); + } + + @Test(expected = IllegalArgumentException.class) + public void setHeaderNameNullIllegalArgumentException() { + this.repository.setHeaderName(null); + } + +} diff --git a/web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java b/web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java index 0143a1fcdc..d68949aaa0 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java @@ -15,13 +15,6 @@ */ package org.springframework.security.web.csrf; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Matchers.any; -import static org.mockito.Matchers.eq; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -30,10 +23,18 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.runners.MockitoJUnitRunner; + import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.TestingAuthenticationToken; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + /** * @author Rob Winch * @@ -55,11 +56,12 @@ public class CsrfAuthenticationStrategyTests { @Before public void setup() { - request = new MockHttpServletRequest(); - response = new MockHttpServletResponse(); - strategy = new CsrfAuthenticationStrategy(csrfTokenRepository); - existingToken = new DefaultCsrfToken("_csrf", "_csrf", "1"); - generatedToken = new DefaultCsrfToken("_csrf", "_csrf", "2"); + this.response = new MockHttpServletResponse(); + this.request = new MockHttpServletRequest(); + this.request.setAttribute(HttpServletResponse.class.getName(), this.response); + this.strategy = new CsrfAuthenticationStrategy(this.csrfTokenRepository); + this.existingToken = new DefaultCsrfToken("_csrf", "_csrf", "1"); + this.generatedToken = new DefaultCsrfToken("_csrf", "_csrf", "2"); } @Test(expected = IllegalArgumentException.class) @@ -69,51 +71,61 @@ public class CsrfAuthenticationStrategyTests { @Test public void logoutRemovesCsrfTokenAndSavesNew() { - when(csrfTokenRepository.loadToken(request)).thenReturn(existingToken); - when(csrfTokenRepository.generateToken(request)).thenReturn(generatedToken); - strategy.onAuthentication(new TestingAuthenticationToken("user", "password", - "ROLE_USER"), request, response); + when(this.csrfTokenRepository.loadToken(this.request)) + .thenReturn(this.existingToken); + when(this.csrfTokenRepository.generateToken(this.request)) + .thenReturn(this.generatedToken); + this.strategy.onAuthentication( + new TestingAuthenticationToken("user", "password", "ROLE_USER"), + this.request, this.response); - verify(csrfTokenRepository).saveToken(null, request, response); - verify(csrfTokenRepository, never()).saveToken(eq(generatedToken), + verify(this.csrfTokenRepository).saveToken(null, this.request, this.response); + verify(this.csrfTokenRepository).saveToken(eq(this.generatedToken), any(HttpServletRequest.class), any(HttpServletResponse.class)); // SEC-2404, SEC-2832 - CsrfToken tokenInRequest = (CsrfToken) request.getAttribute(CsrfToken.class - .getName()); - assertThat(tokenInRequest.getToken()).isSameAs(generatedToken.getToken()); - assertThat(tokenInRequest.getHeaderName()).isSameAs( - generatedToken.getHeaderName()); - assertThat(tokenInRequest.getParameterName()).isSameAs( - generatedToken.getParameterName()); - assertThat(request.getAttribute(generatedToken.getParameterName())).isSameAs( - tokenInRequest); + CsrfToken tokenInRequest = (CsrfToken) this.request + .getAttribute(CsrfToken.class.getName()); + assertThat(tokenInRequest.getToken()).isSameAs(this.generatedToken.getToken()); + assertThat(tokenInRequest.getHeaderName()) + .isSameAs(this.generatedToken.getHeaderName()); + assertThat(tokenInRequest.getParameterName()) + .isSameAs(this.generatedToken.getParameterName()); + assertThat(this.request.getAttribute(this.generatedToken.getParameterName())) + .isSameAs(tokenInRequest); } // SEC-2872 @Test public void delaySavingCsrf() { - when(csrfTokenRepository.loadToken(request)).thenReturn(existingToken); - when(csrfTokenRepository.generateToken(request)).thenReturn(generatedToken); - strategy.onAuthentication(new TestingAuthenticationToken("user", "password", - "ROLE_USER"), request, response); + this.strategy = new CsrfAuthenticationStrategy( + new LazyCsrfTokenRepository(this.csrfTokenRepository)); - verify(csrfTokenRepository).saveToken(null, request, response); - verify(csrfTokenRepository, never()).saveToken(eq(generatedToken), + when(this.csrfTokenRepository.loadToken(this.request)) + .thenReturn(this.existingToken); + when(this.csrfTokenRepository.generateToken(this.request)) + .thenReturn(this.generatedToken); + this.strategy.onAuthentication( + new TestingAuthenticationToken("user", "password", "ROLE_USER"), + this.request, this.response); + + verify(this.csrfTokenRepository).saveToken(null, this.request, this.response); + verify(this.csrfTokenRepository, never()).saveToken(eq(this.generatedToken), any(HttpServletRequest.class), any(HttpServletResponse.class)); - CsrfToken tokenInRequest = (CsrfToken) request.getAttribute(CsrfToken.class - .getName()); + CsrfToken tokenInRequest = (CsrfToken) this.request + .getAttribute(CsrfToken.class.getName()); tokenInRequest.getToken(); - verify(csrfTokenRepository).saveToken(eq(generatedToken), + verify(this.csrfTokenRepository).saveToken(eq(this.generatedToken), any(HttpServletRequest.class), any(HttpServletResponse.class)); } @Test public void logoutRemovesNoActionIfNullToken() { - strategy.onAuthentication(new TestingAuthenticationToken("user", "password", - "ROLE_USER"), request, response); + this.strategy.onAuthentication( + new TestingAuthenticationToken("user", "password", "ROLE_USER"), + this.request, this.response); - verify(csrfTokenRepository, never()).saveToken(any(CsrfToken.class), + verify(this.csrfTokenRepository, never()).saveToken(any(CsrfToken.class), any(HttpServletRequest.class), any(HttpServletResponse.class)); } } diff --git a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java index 5fe68038a7..970276788c 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java @@ -15,14 +15,6 @@ */ package org.springframework.security.web.csrf; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Matchers.any; -import static org.mockito.Matchers.eq; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyZeroInteractions; -import static org.mockito.Mockito.when; - import java.io.IOException; import java.util.Arrays; @@ -38,11 +30,21 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.runners.MockitoJUnitRunner; + import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.web.access.AccessDeniedHandler; import org.springframework.security.web.util.matcher.RequestMatcher; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.when; + /** * @author Rob Winch * @@ -67,16 +69,21 @@ public class CsrfFilterTests { @Before public void setup() { - token = new DefaultCsrfToken("headerName", "paramName", "csrfTokenValue"); + this.token = new DefaultCsrfToken("headerName", "paramName", "csrfTokenValue"); resetRequestResponse(); - filter = new CsrfFilter(tokenRepository); - filter.setRequireCsrfProtectionMatcher(requestMatcher); - filter.setAccessDeniedHandler(deniedHandler); + this.filter = createCsrfFilter(this.tokenRepository); + } + + private CsrfFilter createCsrfFilter(CsrfTokenRepository repository) { + CsrfFilter filter = new CsrfFilter(repository); + filter.setRequireCsrfProtectionMatcher(this.requestMatcher); + filter.setAccessDeniedHandler(this.deniedHandler); + return filter; } private void resetRequestResponse() { - request = new MockHttpServletRequest(); - response = new MockHttpServletResponse(); + this.request = new MockHttpServletRequest(); + this.response = new MockHttpServletResponse(); } @Test(expected = IllegalArgumentException.class) @@ -86,282 +93,319 @@ public class CsrfFilterTests { // SEC-2276 @Test - public void doFilterDoesNotSaveCsrfTokenUntilAccessed() throws ServletException, - IOException { - when(requestMatcher.matches(request)).thenReturn(false); - when(tokenRepository.generateToken(request)).thenReturn(token); + public void doFilterDoesNotSaveCsrfTokenUntilAccessed() + throws ServletException, IOException { + this.filter = createCsrfFilter(new LazyCsrfTokenRepository(this.tokenRepository)); + when(this.requestMatcher.matches(this.request)).thenReturn(false); + when(this.tokenRepository.generateToken(this.request)).thenReturn(this.token); - filter.doFilter(request, response, filterChain); - CsrfToken attrToken = (CsrfToken) request.getAttribute(token.getParameterName()); + this.filter.doFilter(this.request, this.response, this.filterChain); + CsrfToken attrToken = (CsrfToken) this.request + .getAttribute(this.token.getParameterName()); // no CsrfToken should have been saved yet - verify(tokenRepository, times(0)).saveToken(any(CsrfToken.class), + verify(this.tokenRepository, times(0)).saveToken(any(CsrfToken.class), any(HttpServletRequest.class), any(HttpServletResponse.class)); - verify(filterChain).doFilter(request, response); + verify(this.filterChain).doFilter(this.request, this.response); // access the token attrToken.getToken(); // now the CsrfToken should have been saved - verify(tokenRepository).saveToken(eq(token), any(HttpServletRequest.class), - any(HttpServletResponse.class)); + verify(this.tokenRepository).saveToken(eq(this.token), + any(HttpServletRequest.class), any(HttpServletResponse.class)); } @Test - public void doFilterAccessDeniedNoTokenPresent() throws ServletException, IOException { - when(requestMatcher.matches(request)).thenReturn(true); - when(tokenRepository.loadToken(request)).thenReturn(token); + public void doFilterAccessDeniedNoTokenPresent() + throws ServletException, IOException { + when(this.requestMatcher.matches(this.request)).thenReturn(true); + when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token); - filter.doFilter(request, response, filterChain); + this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); - assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); + assertThat(this.request.getAttribute(this.token.getParameterName())) + .isEqualTo(this.token); + assertThat(this.request.getAttribute(CsrfToken.class.getName())) + .isEqualTo(this.token); - verify(deniedHandler).handle(eq(request), eq(response), + verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); - verifyZeroInteractions(filterChain); + verifyZeroInteractions(this.filterChain); } @Test - public void doFilterAccessDeniedIncorrectTokenPresent() throws ServletException, - IOException { - when(requestMatcher.matches(request)).thenReturn(true); - when(tokenRepository.loadToken(request)).thenReturn(token); - request.setParameter(token.getParameterName(), token.getToken() + " INVALID"); + public void doFilterAccessDeniedIncorrectTokenPresent() + throws ServletException, IOException { + when(this.requestMatcher.matches(this.request)).thenReturn(true); + when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token); + this.request.setParameter(this.token.getParameterName(), + this.token.getToken() + " INVALID"); - filter.doFilter(request, response, filterChain); + this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); - assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); + assertThat(this.request.getAttribute(this.token.getParameterName())) + .isEqualTo(this.token); + assertThat(this.request.getAttribute(CsrfToken.class.getName())) + .isEqualTo(this.token); - verify(deniedHandler).handle(eq(request), eq(response), + verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); - verifyZeroInteractions(filterChain); + verifyZeroInteractions(this.filterChain); } @Test public void doFilterAccessDeniedIncorrectTokenPresentHeader() throws ServletException, IOException { - when(requestMatcher.matches(request)).thenReturn(true); - when(tokenRepository.loadToken(request)).thenReturn(token); - request.addHeader(token.getHeaderName(), token.getToken() + " INVALID"); + when(this.requestMatcher.matches(this.request)).thenReturn(true); + when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token); + this.request.addHeader(this.token.getHeaderName(), + this.token.getToken() + " INVALID"); - filter.doFilter(request, response, filterChain); + this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); - assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); + assertThat(this.request.getAttribute(this.token.getParameterName())) + .isEqualTo(this.token); + assertThat(this.request.getAttribute(CsrfToken.class.getName())) + .isEqualTo(this.token); - verify(deniedHandler).handle(eq(request), eq(response), + verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); - verifyZeroInteractions(filterChain); + verifyZeroInteractions(this.filterChain); } @Test public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter() throws ServletException, IOException { - when(requestMatcher.matches(request)).thenReturn(true); - when(tokenRepository.loadToken(request)).thenReturn(token); - request.setParameter(token.getParameterName(), token.getToken()); - request.addHeader(token.getHeaderName(), token.getToken() + " INVALID"); + when(this.requestMatcher.matches(this.request)).thenReturn(true); + when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token); + this.request.setParameter(this.token.getParameterName(), this.token.getToken()); + this.request.addHeader(this.token.getHeaderName(), + this.token.getToken() + " INVALID"); - filter.doFilter(request, response, filterChain); + this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); - assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); + assertThat(this.request.getAttribute(this.token.getParameterName())) + .isEqualTo(this.token); + assertThat(this.request.getAttribute(CsrfToken.class.getName())) + .isEqualTo(this.token); - verify(deniedHandler).handle(eq(request), eq(response), + verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); - verifyZeroInteractions(filterChain); + verifyZeroInteractions(this.filterChain); } @Test - public void doFilterNotCsrfRequestExistingToken() throws ServletException, - IOException { - when(requestMatcher.matches(request)).thenReturn(false); - when(tokenRepository.loadToken(request)).thenReturn(token); + public void doFilterNotCsrfRequestExistingToken() + throws ServletException, IOException { + when(this.requestMatcher.matches(this.request)).thenReturn(false); + when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token); - filter.doFilter(request, response, filterChain); + this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); - assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); + assertThat(this.request.getAttribute(this.token.getParameterName())) + .isEqualTo(this.token); + assertThat(this.request.getAttribute(CsrfToken.class.getName())) + .isEqualTo(this.token); - verify(filterChain).doFilter(request, response); - verifyZeroInteractions(deniedHandler); + verify(this.filterChain).doFilter(this.request, this.response); + verifyZeroInteractions(this.deniedHandler); } @Test - public void doFilterNotCsrfRequestGenerateToken() throws ServletException, - IOException { - when(requestMatcher.matches(request)).thenReturn(false); - when(tokenRepository.generateToken(request)).thenReturn(token); + public void doFilterNotCsrfRequestGenerateToken() + throws ServletException, IOException { + when(this.requestMatcher.matches(this.request)).thenReturn(false); + when(this.tokenRepository.generateToken(this.request)).thenReturn(this.token); - filter.doFilter(request, response, filterChain); + this.filter.doFilter(this.request, this.response, this.filterChain); - assertToken(request.getAttribute(token.getParameterName())).isEqualTo(token); - assertToken(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); + assertToken(this.request.getAttribute(this.token.getParameterName())) + .isEqualTo(this.token); + assertToken(this.request.getAttribute(CsrfToken.class.getName())) + .isEqualTo(this.token); - verify(filterChain).doFilter(request, response); - verifyZeroInteractions(deniedHandler); + verify(this.filterChain).doFilter(this.request, this.response); + verifyZeroInteractions(this.deniedHandler); } @Test - public void doFilterIsCsrfRequestExistingTokenHeader() throws ServletException, - IOException { - when(requestMatcher.matches(request)).thenReturn(true); - when(tokenRepository.loadToken(request)).thenReturn(token); - request.addHeader(token.getHeaderName(), token.getToken()); + public void doFilterIsCsrfRequestExistingTokenHeader() + throws ServletException, IOException { + when(this.requestMatcher.matches(this.request)).thenReturn(true); + when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token); + this.request.addHeader(this.token.getHeaderName(), this.token.getToken()); - filter.doFilter(request, response, filterChain); + this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); - assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); + assertThat(this.request.getAttribute(this.token.getParameterName())) + .isEqualTo(this.token); + assertThat(this.request.getAttribute(CsrfToken.class.getName())) + .isEqualTo(this.token); - verify(filterChain).doFilter(request, response); - verifyZeroInteractions(deniedHandler); + verify(this.filterChain).doFilter(this.request, this.response); + verifyZeroInteractions(this.deniedHandler); } @Test public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam() throws ServletException, IOException { - when(requestMatcher.matches(request)).thenReturn(true); - when(tokenRepository.loadToken(request)).thenReturn(token); - request.setParameter(token.getParameterName(), token.getToken() + " INVALID"); - request.addHeader(token.getHeaderName(), token.getToken()); + when(this.requestMatcher.matches(this.request)).thenReturn(true); + when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token); + this.request.setParameter(this.token.getParameterName(), + this.token.getToken() + " INVALID"); + this.request.addHeader(this.token.getHeaderName(), this.token.getToken()); - filter.doFilter(request, response, filterChain); + this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); - assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); + assertThat(this.request.getAttribute(this.token.getParameterName())) + .isEqualTo(this.token); + assertThat(this.request.getAttribute(CsrfToken.class.getName())) + .isEqualTo(this.token); - verify(filterChain).doFilter(request, response); - verifyZeroInteractions(deniedHandler); + verify(this.filterChain).doFilter(this.request, this.response); + verifyZeroInteractions(this.deniedHandler); } @Test - public void doFilterIsCsrfRequestExistingToken() throws ServletException, IOException { - when(requestMatcher.matches(request)).thenReturn(true); - when(tokenRepository.loadToken(request)).thenReturn(token); - request.setParameter(token.getParameterName(), token.getToken()); + public void doFilterIsCsrfRequestExistingToken() + throws ServletException, IOException { + when(this.requestMatcher.matches(this.request)).thenReturn(true); + when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token); + this.request.setParameter(this.token.getParameterName(), this.token.getToken()); - filter.doFilter(request, response, filterChain); + this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); - assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); + assertThat(this.request.getAttribute(this.token.getParameterName())) + .isEqualTo(this.token); + assertThat(this.request.getAttribute(CsrfToken.class.getName())) + .isEqualTo(this.token); - verify(filterChain).doFilter(request, response); - verifyZeroInteractions(deniedHandler); + verify(this.filterChain).doFilter(this.request, this.response); + verifyZeroInteractions(this.deniedHandler); + verify(this.tokenRepository, never()).saveToken(any(CsrfToken.class), + any(HttpServletRequest.class), any(HttpServletResponse.class)); } @Test - public void doFilterIsCsrfRequestGenerateToken() throws ServletException, IOException { - when(requestMatcher.matches(request)).thenReturn(true); - when(tokenRepository.generateToken(request)).thenReturn(token); - request.setParameter(token.getParameterName(), token.getToken()); + public void doFilterIsCsrfRequestGenerateToken() + throws ServletException, IOException { + when(this.requestMatcher.matches(this.request)).thenReturn(true); + when(this.tokenRepository.generateToken(this.request)).thenReturn(this.token); + this.request.setParameter(this.token.getParameterName(), this.token.getToken()); - filter.doFilter(request, response, filterChain); + this.filter.doFilter(this.request, this.response, this.filterChain); - assertToken(request.getAttribute(token.getParameterName())).isEqualTo(token); - assertToken(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); + assertToken(this.request.getAttribute(this.token.getParameterName())) + .isEqualTo(this.token); + assertToken(this.request.getAttribute(CsrfToken.class.getName())) + .isEqualTo(this.token); - verify(filterChain).doFilter(request, response); - verifyZeroInteractions(deniedHandler); + // LazyCsrfTokenRepository requires the response as an attribute + assertThat(this.request.getAttribute(HttpServletResponse.class.getName())) + .isEqualTo(this.response); + + verify(this.filterChain).doFilter(this.request, this.response); + verify(this.tokenRepository).saveToken(this.token, this.request, this.response); + verifyZeroInteractions(this.deniedHandler); } @Test public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods() throws ServletException, IOException { - filter = new CsrfFilter(tokenRepository); - filter.setAccessDeniedHandler(deniedHandler); + this.filter = new CsrfFilter(this.tokenRepository); + this.filter.setAccessDeniedHandler(this.deniedHandler); for (String method : Arrays.asList("GET", "TRACE", "OPTIONS", "HEAD")) { resetRequestResponse(); - when(tokenRepository.loadToken(request)).thenReturn(token); - request.setMethod(method); + when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token); + this.request.setMethod(method); - filter.doFilter(request, response, filterChain); + this.filter.doFilter(this.request, this.response, this.filterChain); - verify(filterChain).doFilter(request, response); - verifyZeroInteractions(deniedHandler); + verify(this.filterChain).doFilter(this.request, this.response); + verifyZeroInteractions(this.deniedHandler); } } /** * SEC-2292 Should not allow other cases through since spec states HTTP method is case * sensitive http://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html#sec5.1.1 + * @throws Exception if an error occurs * - * @throws ServletException - * @throws IOException */ @Test public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethodsCaseSensitive() - throws ServletException, IOException { - filter = new CsrfFilter(tokenRepository); - filter.setAccessDeniedHandler(deniedHandler); + throws Exception { + this.filter = new CsrfFilter(this.tokenRepository); + this.filter.setAccessDeniedHandler(this.deniedHandler); for (String method : Arrays.asList("get", "TrAcE", "oPTIOnS", "hEaD")) { resetRequestResponse(); - when(tokenRepository.loadToken(request)).thenReturn(token); - request.setMethod(method); + when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token); + this.request.setMethod(method); - filter.doFilter(request, response, filterChain); + this.filter.doFilter(this.request, this.response, this.filterChain); - verify(deniedHandler).handle(eq(request), eq(response), + verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); - verifyZeroInteractions(filterChain); + verifyZeroInteractions(this.filterChain); } } @Test public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods() throws ServletException, IOException { - filter = new CsrfFilter(tokenRepository); - filter.setAccessDeniedHandler(deniedHandler); + this.filter = new CsrfFilter(this.tokenRepository); + this.filter.setAccessDeniedHandler(this.deniedHandler); for (String method : Arrays.asList("POST", "PUT", "PATCH", "DELETE", "INVALID")) { resetRequestResponse(); - when(tokenRepository.loadToken(request)).thenReturn(token); - request.setMethod(method); + when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token); + this.request.setMethod(method); - filter.doFilter(request, response, filterChain); + this.filter.doFilter(this.request, this.response, this.filterChain); - verify(deniedHandler).handle(eq(request), eq(response), + verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); - verifyZeroInteractions(filterChain); + verifyZeroInteractions(this.filterChain); } } @Test public void doFilterDefaultAccessDenied() throws ServletException, IOException { - filter = new CsrfFilter(tokenRepository); - filter.setRequireCsrfProtectionMatcher(requestMatcher); - when(requestMatcher.matches(request)).thenReturn(true); - when(tokenRepository.loadToken(request)).thenReturn(token); + this.filter = new CsrfFilter(this.tokenRepository); + this.filter.setRequireCsrfProtectionMatcher(this.requestMatcher); + when(this.requestMatcher.matches(this.request)).thenReturn(true); + when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token); - filter.doFilter(request, response, filterChain); + this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); - assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); + assertThat(this.request.getAttribute(this.token.getParameterName())) + .isEqualTo(this.token); + assertThat(this.request.getAttribute(CsrfToken.class.getName())) + .isEqualTo(this.token); - assertThat(response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); - verifyZeroInteractions(filterChain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); + verifyZeroInteractions(this.filterChain); } @Test(expected = IllegalArgumentException.class) public void setRequireCsrfProtectionMatcherNull() { - filter.setRequireCsrfProtectionMatcher(null); + this.filter.setRequireCsrfProtectionMatcher(null); } @Test(expected = IllegalArgumentException.class) public void setAccessDeniedHandlerNull() { - filter.setAccessDeniedHandler(null); + this.filter.setAccessDeniedHandler(null); } private static final CsrfTokenAssert assertToken(Object token) { return new CsrfTokenAssert((CsrfToken) token); } - private static class CsrfTokenAssert extends - AbstractObjectAssert { + private static class CsrfTokenAssert + extends AbstractObjectAssert { /** * Creates a new {@link ObjectAssert}. @@ -369,13 +413,14 @@ public class CsrfFilterTests { * @param actual the target to verify. */ protected CsrfTokenAssert(CsrfToken actual) { - super(actual,CsrfTokenAssert.class); + super(actual, CsrfTokenAssert.class); } public CsrfTokenAssert isEqualTo(CsrfToken expected) { - assertThat(actual.getHeaderName()).isEqualTo(expected.getHeaderName()); - assertThat(actual.getParameterName()).isEqualTo(expected.getParameterName()); - assertThat(actual.getToken()).isEqualTo(expected.getToken()); + assertThat(this.actual.getHeaderName()).isEqualTo(expected.getHeaderName()); + assertThat(this.actual.getParameterName()) + .isEqualTo(expected.getParameterName()); + assertThat(this.actual.getToken()).isEqualTo(expected.getToken()); return this; } } diff --git a/web/src/test/java/org/springframework/security/web/csrf/LazyCsrfTokenRepositoryTests.java b/web/src/test/java/org/springframework/security/web/csrf/LazyCsrfTokenRepositoryTests.java new file mode 100644 index 0000000000..e0d50666cf --- /dev/null +++ b/web/src/test/java/org/springframework/security/web/csrf/LazyCsrfTokenRepositoryTests.java @@ -0,0 +1,102 @@ +/* + * Copyright 2012-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web.csrf; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.runners.MockitoJUnitRunner; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.when; + +/** + * @author Rob Winch + */ +@RunWith(MockitoJUnitRunner.class) +public class LazyCsrfTokenRepositoryTests { + @Mock + CsrfTokenRepository delegate; + @Mock + HttpServletRequest request; + @Mock + HttpServletResponse response; + + @InjectMocks + LazyCsrfTokenRepository repository; + + DefaultCsrfToken token; + + @Before + public void setup() { + this.token = new DefaultCsrfToken("header", "param", "token"); + when(this.delegate.generateToken(this.request)).thenReturn(this.token); + when(this.request.getAttribute(HttpServletResponse.class.getName())) + .thenReturn(this.response); + } + + @Test(expected = IllegalArgumentException.class) + public void constructNullDelegateThrowsIllegalArgumentException() { + new LazyCsrfTokenRepository(null); + } + + @Test(expected = IllegalArgumentException.class) + public void generateTokenNullResponseAttribute() { + this.repository.generateToken(mock(HttpServletRequest.class)); + } + + @Test + public void generateTokenGetTokenSavesToken() { + CsrfToken newToken = this.repository.generateToken(this.request); + + newToken.getToken(); + + verify(this.delegate).saveToken(this.token, this.request, this.response); + } + + @Test + public void saveNonNullDoesNothing() { + this.repository.saveToken(this.token, this.request, this.response); + + verifyZeroInteractions(this.delegate); + } + + @Test + public void saveNullDelegates() { + this.repository.saveToken(null, this.request, this.response); + + verify(this.delegate).saveToken(null, this.request, this.response); + } + + @Test + public void loadTokenDelegates() { + when(this.delegate.loadToken(this.request)).thenReturn(this.token); + + CsrfToken loadToken = this.repository.loadToken(this.request); + assertThat(loadToken).isSameAs(this.token); + + verify(this.delegate).loadToken(this.request); + } +} \ No newline at end of file