Add CsrfTokenRepository (#3805)

* Create LazyCsrfTokenRepository

Fixes gh-3790

* Add CookieCsrfTokenRepository

Fixes gh-3009
This commit is contained in:
Rob Winch 2016-04-12 16:26:53 -05:00 committed by Joe Grandja
parent e9cb92bb74
commit d3a9cc6eae
14 changed files with 1461 additions and 857 deletions

View File

@ -33,6 +33,7 @@ import org.springframework.security.web.csrf.CsrfFilter;
import org.springframework.security.web.csrf.CsrfLogoutHandler; import org.springframework.security.web.csrf.CsrfLogoutHandler;
import org.springframework.security.web.csrf.CsrfTokenRepository; import org.springframework.security.web.csrf.CsrfTokenRepository;
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
import org.springframework.security.web.csrf.LazyCsrfTokenRepository;
import org.springframework.security.web.csrf.MissingCsrfTokenException; import org.springframework.security.web.csrf.MissingCsrfTokenException;
import org.springframework.security.web.session.InvalidSessionAccessDeniedHandler; import org.springframework.security.web.session.InvalidSessionAccessDeniedHandler;
import org.springframework.security.web.session.InvalidSessionStrategy; import org.springframework.security.web.session.InvalidSessionStrategy;
@ -43,8 +44,9 @@ import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert; import org.springframework.util.Assert;
/** /**
* Adds <a href="https://www.owasp.org/index.php/Cross-Site_Request_Forgery_(CSRF)" * Adds
* >CSRF</a> protection for the methods as specified by * <a href="https://www.owasp.org/index.php/Cross-Site_Request_Forgery_(CSRF)" >CSRF</a>
* protection for the methods as specified by
* {@link #requireCsrfProtectionMatcher(RequestMatcher)}. * {@link #requireCsrfProtectionMatcher(RequestMatcher)}.
* *
* <h2>Security Filters</h2> * <h2>Security Filters</h2>
@ -62,18 +64,18 @@ import org.springframework.util.Assert;
* <h2>Shared Objects Used</h2> * <h2>Shared Objects Used</h2>
* *
* <ul> * <ul>
* <li> * <li>{@link ExceptionHandlingConfigurer#accessDeniedHandler(AccessDeniedHandler)} is
* {@link ExceptionHandlingConfigurer#accessDeniedHandler(AccessDeniedHandler)} is used to * used to determine how to handle CSRF attempts</li>
* determine how to handle CSRF attempts</li>
* <li>{@link InvalidSessionStrategy}</li> * <li>{@link InvalidSessionStrategy}</li>
* </ul> * </ul>
* *
* @author Rob Winch * @author Rob Winch
* @since 3.2 * @since 3.2
*/ */
public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>> extends public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>>
AbstractHttpConfigurer<CsrfConfigurer<H>, H> { extends AbstractHttpConfigurer<CsrfConfigurer<H>, H> {
private CsrfTokenRepository csrfTokenRepository = new HttpSessionCsrfTokenRepository(); private CsrfTokenRepository csrfTokenRepository = new LazyCsrfTokenRepository(
new HttpSessionCsrfTokenRepository());
private RequestMatcher requireCsrfProtectionMatcher = CsrfFilter.DEFAULT_CSRF_MATCHER; private RequestMatcher requireCsrfProtectionMatcher = CsrfFilter.DEFAULT_CSRF_MATCHER;
private List<RequestMatcher> ignoredCsrfProtectionMatchers = new ArrayList<RequestMatcher>(); private List<RequestMatcher> ignoredCsrfProtectionMatchers = new ArrayList<RequestMatcher>();
@ -86,12 +88,13 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>> extends
/** /**
* Specify the {@link CsrfTokenRepository} to use. The default is an * Specify the {@link CsrfTokenRepository} to use. The default is an
* {@link HttpSessionCsrfTokenRepository}. * {@link HttpSessionCsrfTokenRepository} wrapped by {@link LazyCsrfTokenRepository}.
* *
* @param csrfTokenRepository the {@link CsrfTokenRepository} to use * @param csrfTokenRepository the {@link CsrfTokenRepository} to use
* @return the {@link CsrfConfigurer} for further customizations * @return the {@link CsrfConfigurer} for further customizations
*/ */
public CsrfConfigurer<H> csrfTokenRepository(CsrfTokenRepository csrfTokenRepository) { public CsrfConfigurer<H> csrfTokenRepository(
CsrfTokenRepository csrfTokenRepository) {
Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null"); Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null");
this.csrfTokenRepository = csrfTokenRepository; this.csrfTokenRepository = csrfTokenRepository;
return this; return this;
@ -144,7 +147,7 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>> extends
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Override @Override
public void configure(H http) throws Exception { public void configure(H http) throws Exception {
CsrfFilter filter = new CsrfFilter(csrfTokenRepository); CsrfFilter filter = new CsrfFilter(this.csrfTokenRepository);
RequestMatcher requireCsrfProtectionMatcher = getRequireCsrfProtectionMatcher(); RequestMatcher requireCsrfProtectionMatcher = getRequireCsrfProtectionMatcher();
if (requireCsrfProtectionMatcher != null) { if (requireCsrfProtectionMatcher != null) {
filter.setRequireCsrfProtectionMatcher(requireCsrfProtectionMatcher); filter.setRequireCsrfProtectionMatcher(requireCsrfProtectionMatcher);
@ -155,14 +158,14 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>> extends
} }
LogoutConfigurer<H> logoutConfigurer = http.getConfigurer(LogoutConfigurer.class); LogoutConfigurer<H> logoutConfigurer = http.getConfigurer(LogoutConfigurer.class);
if (logoutConfigurer != null) { if (logoutConfigurer != null) {
logoutConfigurer.addLogoutHandler(new CsrfLogoutHandler(csrfTokenRepository)); logoutConfigurer
.addLogoutHandler(new CsrfLogoutHandler(this.csrfTokenRepository));
} }
SessionManagementConfigurer<H> sessionConfigurer = http SessionManagementConfigurer<H> sessionConfigurer = http
.getConfigurer(SessionManagementConfigurer.class); .getConfigurer(SessionManagementConfigurer.class);
if (sessionConfigurer != null) { if (sessionConfigurer != null) {
sessionConfigurer sessionConfigurer.addSessionAuthenticationStrategy(
.addSessionAuthenticationStrategy(new CsrfAuthenticationStrategy( new CsrfAuthenticationStrategy(this.csrfTokenRepository));
csrfTokenRepository));
} }
filter = postProcess(filter); filter = postProcess(filter);
http.addFilter(filter); http.addFilter(filter);
@ -175,12 +178,12 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>> extends
* @return the {@link RequestMatcher} to use * @return the {@link RequestMatcher} to use
*/ */
private RequestMatcher getRequireCsrfProtectionMatcher() { private RequestMatcher getRequireCsrfProtectionMatcher() {
if (ignoredCsrfProtectionMatchers.isEmpty()) { if (this.ignoredCsrfProtectionMatchers.isEmpty()) {
return requireCsrfProtectionMatcher; return this.requireCsrfProtectionMatcher;
} }
return new AndRequestMatcher(requireCsrfProtectionMatcher, return new AndRequestMatcher(this.requireCsrfProtectionMatcher,
new NegatedRequestMatcher(new OrRequestMatcher( new NegatedRequestMatcher(
ignoredCsrfProtectionMatchers))); new OrRequestMatcher(this.ignoredCsrfProtectionMatchers)));
} }
/** /**
@ -238,7 +241,8 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>> extends
*/ */
private AccessDeniedHandler createAccessDeniedHandler(H http) { private AccessDeniedHandler createAccessDeniedHandler(H http) {
InvalidSessionStrategy invalidSessionStrategy = getInvalidSessionStrategy(http); InvalidSessionStrategy invalidSessionStrategy = getInvalidSessionStrategy(http);
AccessDeniedHandler defaultAccessDeniedHandler = getDefaultAccessDeniedHandler(http); AccessDeniedHandler defaultAccessDeniedHandler = getDefaultAccessDeniedHandler(
http);
if (invalidSessionStrategy == null) { if (invalidSessionStrategy == null) {
return defaultAccessDeniedHandler; return defaultAccessDeniedHandler;
} }
@ -258,16 +262,17 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>> extends
* @author Rob Winch * @author Rob Winch
* @since 4.0 * @since 4.0
*/ */
private class IgnoreCsrfProtectionRegistry extends private class IgnoreCsrfProtectionRegistry
AbstractRequestMatcherRegistry<IgnoreCsrfProtectionRegistry> { extends AbstractRequestMatcherRegistry<IgnoreCsrfProtectionRegistry> {
public CsrfConfigurer<H> and() { public CsrfConfigurer<H> and() {
return CsrfConfigurer.this; return CsrfConfigurer.this;
} }
@Override
protected IgnoreCsrfProtectionRegistry chainRequestMatchers( protected IgnoreCsrfProtectionRegistry chainRequestMatchers(
List<RequestMatcher> requestMatchers) { List<RequestMatcher> requestMatchers) {
ignoredCsrfProtectionMatchers.addAll(requestMatchers); CsrfConfigurer.this.ignoredCsrfProtectionMatchers.addAll(requestMatchers);
return this; return this;
} }
} }

View File

@ -15,6 +15,8 @@
*/ */
package org.springframework.security.config.http; package org.springframework.security.config.http;
import org.w3c.dom.Element;
import org.springframework.beans.BeanMetadataElement; import org.springframework.beans.BeanMetadataElement;
import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.parsing.BeanComponentDefinition; import org.springframework.beans.factory.parsing.BeanComponentDefinition;
@ -31,13 +33,13 @@ import org.springframework.security.web.csrf.CsrfAuthenticationStrategy;
import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfFilter;
import org.springframework.security.web.csrf.CsrfLogoutHandler; import org.springframework.security.web.csrf.CsrfLogoutHandler;
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
import org.springframework.security.web.csrf.LazyCsrfTokenRepository;
import org.springframework.security.web.csrf.MissingCsrfTokenException; import org.springframework.security.web.csrf.MissingCsrfTokenException;
import org.springframework.security.web.servlet.support.csrf.CsrfRequestDataValueProcessor; import org.springframework.security.web.servlet.support.csrf.CsrfRequestDataValueProcessor;
import org.springframework.security.web.session.InvalidSessionAccessDeniedHandler; import org.springframework.security.web.session.InvalidSessionAccessDeniedHandler;
import org.springframework.security.web.session.InvalidSessionStrategy; import org.springframework.security.web.session.InvalidSessionStrategy;
import org.springframework.util.ClassUtils; import org.springframework.util.ClassUtils;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import org.w3c.dom.Element;
/** /**
* Parser for the {@code CsrfFilter}. * Parser for the {@code CsrfFilter}.
@ -55,6 +57,7 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser {
private String csrfRepositoryRef; private String csrfRepositoryRef;
private BeanDefinition csrfFilter; private BeanDefinition csrfFilter;
@Override
public BeanDefinition parse(Element element, ParserContext pc) { public BeanDefinition parse(Element element, ParserContext pc) {
boolean disabled = element != null boolean disabled = element != null
&& "true".equals(element.getAttribute("disabled")); && "true".equals(element.getAttribute("disabled"));
@ -73,29 +76,33 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser {
String matcherRef = null; String matcherRef = null;
if (element != null) { if (element != null) {
csrfRepositoryRef = element.getAttribute(ATT_REPOSITORY); this.csrfRepositoryRef = element.getAttribute(ATT_REPOSITORY);
matcherRef = element.getAttribute(ATT_MATCHER); matcherRef = element.getAttribute(ATT_MATCHER);
} }
if (!StringUtils.hasText(csrfRepositoryRef)) { if (!StringUtils.hasText(this.csrfRepositoryRef)) {
RootBeanDefinition csrfTokenRepository = new RootBeanDefinition( RootBeanDefinition csrfTokenRepository = new RootBeanDefinition(
HttpSessionCsrfTokenRepository.class); HttpSessionCsrfTokenRepository.class);
csrfRepositoryRef = pc.getReaderContext().generateBeanName( BeanDefinitionBuilder lazyTokenRepository = BeanDefinitionBuilder
csrfTokenRepository); .rootBeanDefinition(LazyCsrfTokenRepository.class);
pc.registerBeanComponent(new BeanComponentDefinition(csrfTokenRepository, lazyTokenRepository.addConstructorArgValue(csrfTokenRepository);
csrfRepositoryRef)); this.csrfRepositoryRef = pc.getReaderContext()
.generateBeanName(lazyTokenRepository.getBeanDefinition());
pc.registerBeanComponent(new BeanComponentDefinition(
lazyTokenRepository.getBeanDefinition(), this.csrfRepositoryRef));
} }
BeanDefinitionBuilder builder = BeanDefinitionBuilder BeanDefinitionBuilder builder = BeanDefinitionBuilder
.rootBeanDefinition(CsrfFilter.class); .rootBeanDefinition(CsrfFilter.class);
builder.addConstructorArgReference(csrfRepositoryRef); builder.addConstructorArgReference(this.csrfRepositoryRef);
if (StringUtils.hasText(matcherRef)) { if (StringUtils.hasText(matcherRef)) {
builder.addPropertyReference("requireCsrfProtectionMatcher", matcherRef); builder.addPropertyReference("requireCsrfProtectionMatcher", matcherRef);
} }
csrfFilter = builder.getBeanDefinition(); this.csrfFilter = builder.getBeanDefinition();
return csrfFilter; return this.csrfFilter;
} }
/** /**
@ -108,7 +115,7 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser {
BeanMetadataElement defaultDeniedHandler) { BeanMetadataElement defaultDeniedHandler) {
BeanMetadataElement accessDeniedHandler = createAccessDeniedHandler( BeanMetadataElement accessDeniedHandler = createAccessDeniedHandler(
invalidSessionStrategy, defaultDeniedHandler); invalidSessionStrategy, defaultDeniedHandler);
csrfFilter.getPropertyValues().addPropertyValue("accessDeniedHandler", this.csrfFilter.getPropertyValues().addPropertyValue("accessDeniedHandler",
accessDeniedHandler); accessDeniedHandler);
} }
@ -152,14 +159,14 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser {
BeanDefinition getCsrfAuthenticationStrategy() { BeanDefinition getCsrfAuthenticationStrategy() {
BeanDefinitionBuilder csrfAuthenticationStrategy = BeanDefinitionBuilder BeanDefinitionBuilder csrfAuthenticationStrategy = BeanDefinitionBuilder
.rootBeanDefinition(CsrfAuthenticationStrategy.class); .rootBeanDefinition(CsrfAuthenticationStrategy.class);
csrfAuthenticationStrategy.addConstructorArgReference(csrfRepositoryRef); csrfAuthenticationStrategy.addConstructorArgReference(this.csrfRepositoryRef);
return csrfAuthenticationStrategy.getBeanDefinition(); return csrfAuthenticationStrategy.getBeanDefinition();
} }
BeanDefinition getCsrfLogoutHandler() { BeanDefinition getCsrfLogoutHandler() {
BeanDefinitionBuilder csrfAuthenticationStrategy = BeanDefinitionBuilder BeanDefinitionBuilder csrfAuthenticationStrategy = BeanDefinitionBuilder
.rootBeanDefinition(CsrfLogoutHandler.class); .rootBeanDefinition(CsrfLogoutHandler.class);
csrfAuthenticationStrategy.addConstructorArgReference(csrfRepositoryRef); csrfAuthenticationStrategy.addConstructorArgReference(this.csrfRepositoryRef);
return csrfAuthenticationStrategy.getBeanDefinition(); return csrfAuthenticationStrategy.getBeanDefinition();
} }
} }

View File

@ -15,11 +15,10 @@
*/ */
package org.springframework.security.config.annotation.web.configurers package org.springframework.security.config.annotation.web.configurers
import org.springframework.security.web.util.matcher.AntPathRequestMatcher
import javax.servlet.http.HttpServletResponse import javax.servlet.http.HttpServletResponse
import org.springframework.context.annotation.Configuration import spock.lang.Unroll
import org.springframework.mock.web.MockHttpServletRequest import org.springframework.mock.web.MockHttpServletRequest
import org.springframework.mock.web.MockHttpServletResponse import org.springframework.mock.web.MockHttpServletResponse
import org.springframework.security.config.annotation.BaseSpringSpec import org.springframework.security.config.annotation.BaseSpringSpec
@ -27,15 +26,13 @@ import org.springframework.security.config.annotation.authentication.builders.Au
import org.springframework.security.config.annotation.web.builders.HttpSecurity import org.springframework.security.config.annotation.web.builders.HttpSecurity
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter
import org.springframework.security.config.annotation.web.servlet.configuration.EnableWebMvcSecurity;
import org.springframework.security.web.access.AccessDeniedHandler import org.springframework.security.web.access.AccessDeniedHandler
import org.springframework.security.web.csrf.CsrfFilter import org.springframework.security.web.csrf.CsrfFilter
import org.springframework.security.web.csrf.CsrfTokenRepository import org.springframework.security.web.csrf.CsrfTokenRepository
import org.springframework.security.web.util.matcher.AntPathRequestMatcher
import org.springframework.security.web.util.matcher.RequestMatcher import org.springframework.security.web.util.matcher.RequestMatcher
import org.springframework.web.servlet.support.RequestDataValueProcessor import org.springframework.web.servlet.support.RequestDataValueProcessor
import spock.lang.Unroll
/** /**
* *
* @author Rob Winch * @author Rob Winch
@ -45,31 +42,31 @@ class CsrfConfigurerTests extends BaseSpringSpec {
@Unroll @Unroll
def "csrf applied by default"() { def "csrf applied by default"() {
setup: setup:
loadConfig(CsrfAppliedDefaultConfig) loadConfig(CsrfAppliedDefaultConfig)
request.method = httpMethod request.method = httpMethod
clearCsrfToken() clearCsrfToken()
when: when:
springSecurityFilterChain.doFilter(request,response,chain) springSecurityFilterChain.doFilter(request,response,chain)
then: then:
response.status == httpStatus response.status == httpStatus
where: where:
httpMethod | httpStatus httpMethod | httpStatus
'POST' | HttpServletResponse.SC_FORBIDDEN 'POST' | HttpServletResponse.SC_FORBIDDEN
'PUT' | HttpServletResponse.SC_FORBIDDEN 'PUT' | HttpServletResponse.SC_FORBIDDEN
'PATCH' | HttpServletResponse.SC_FORBIDDEN 'PATCH' | HttpServletResponse.SC_FORBIDDEN
'DELETE' | HttpServletResponse.SC_FORBIDDEN 'DELETE' | HttpServletResponse.SC_FORBIDDEN
'INVALID' | HttpServletResponse.SC_FORBIDDEN 'INVALID' | HttpServletResponse.SC_FORBIDDEN
'GET' | HttpServletResponse.SC_OK 'GET' | HttpServletResponse.SC_OK
'HEAD' | HttpServletResponse.SC_OK 'HEAD' | HttpServletResponse.SC_OK
'TRACE' | HttpServletResponse.SC_OK 'TRACE' | HttpServletResponse.SC_OK
'OPTIONS' | HttpServletResponse.SC_OK 'OPTIONS' | HttpServletResponse.SC_OK
} }
def "csrf default creates CsrfRequestDataValueProcessor"() { def "csrf default creates CsrfRequestDataValueProcessor"() {
when: when:
loadConfig(CsrfAppliedDefaultConfig) loadConfig(CsrfAppliedDefaultConfig)
then: then:
context.getBean(RequestDataValueProcessor) context.getBean(RequestDataValueProcessor)
} }
@EnableWebSecurity @EnableWebSecurity
@ -82,14 +79,14 @@ class CsrfConfigurerTests extends BaseSpringSpec {
def "csrf disable"() { def "csrf disable"() {
setup: setup:
loadConfig(DisableCsrfConfig) loadConfig(DisableCsrfConfig)
request.method = "POST" request.method = "POST"
clearCsrfToken() clearCsrfToken()
when: when:
springSecurityFilterChain.doFilter(request,response,chain) springSecurityFilterChain.doFilter(request,response,chain)
then: then:
!findFilter(CsrfFilter) !findFilter(CsrfFilter)
response.status == HttpServletResponse.SC_OK response.status == HttpServletResponse.SC_OK
} }
@EnableWebSecurity @EnableWebSecurity
@ -98,29 +95,29 @@ class CsrfConfigurerTests extends BaseSpringSpec {
@Override @Override
protected void configure(HttpSecurity http) throws Exception { protected void configure(HttpSecurity http) throws Exception {
http http
.csrf().disable() .csrf().disable()
} }
} }
def "SEC-2498: Disable CSRF enables RequestCache for any method"() { def "SEC-2498: Disable CSRF enables RequestCache for any method"() {
setup: setup:
loadConfig(DisableCsrfEnablesRequestCacheConfig) loadConfig(DisableCsrfEnablesRequestCacheConfig)
request.requestURI = '/tosave' request.requestURI = '/tosave'
request.method = "POST" request.method = "POST"
clearCsrfToken() clearCsrfToken()
when: when:
springSecurityFilterChain.doFilter(request,response,chain) springSecurityFilterChain.doFilter(request,response,chain)
then: then:
response.redirectedUrl response.redirectedUrl
when: when:
super.setupWeb(request.session) super.setupWeb(request.session)
request.method = "POST" request.method = "POST"
request.servletPath = '/login' request.servletPath = '/login'
request.parameters['username'] = ['user'] as String[] request.parameters['username'] = ['user'] as String[]
request.parameters['password'] = ['password'] as String[] request.parameters['password'] = ['password'] as String[]
springSecurityFilterChain.doFilter(request,response,chain) springSecurityFilterChain.doFilter(request,response,chain)
then: then:
response.redirectedUrl == 'http://localhost/tosave' response.redirectedUrl == 'http://localhost/tosave'
} }
@EnableWebSecurity @EnableWebSecurity
@ -129,38 +126,37 @@ class CsrfConfigurerTests extends BaseSpringSpec {
@Override @Override
protected void configure(HttpSecurity http) throws Exception { protected void configure(HttpSecurity http) throws Exception {
http http
.authorizeRequests() .authorizeRequests()
.anyRequest().authenticated() .anyRequest().authenticated()
.and() .and()
.formLogin().and() .formLogin().and()
.csrf().disable() .csrf().disable()
} }
@Override @Override
protected void configure(AuthenticationManagerBuilder auth) throws Exception { protected void configure(AuthenticationManagerBuilder auth) throws Exception {
auth auth
.inMemoryAuthentication() .inMemoryAuthentication()
.withUser("user").password("password").roles("USER") .withUser("user").password("password").roles("USER")
} }
} }
def "SEC-2422: csrf expire CSRF token and session-management invalid-session-url"() { def "SEC-2422: csrf expire CSRF token and session-management invalid-session-url"() {
setup: setup:
loadConfig(InvalidSessionUrlConfig) loadConfig(InvalidSessionUrlConfig)
request.session.clearAttributes() request.session.clearAttributes()
request.setParameter("_csrf","abc") request.setParameter("_csrf","abc")
request.method = "POST" request.method = "POST"
when: "No existing expected CsrfToken (session times out) and a POST" when: "No existing expected CsrfToken (session times out) and a POST"
springSecurityFilterChain.doFilter(request,response,chain) springSecurityFilterChain.doFilter(request,response,chain)
then: "sent to the session timeout page page" then: "sent to the session timeout page page"
response.status == HttpServletResponse.SC_MOVED_TEMPORARILY response.status == HttpServletResponse.SC_MOVED_TEMPORARILY
response.redirectedUrl == "/error/sessionError" response.redirectedUrl == "/error/sessionError"
when: "Existing expected CsrfToken and a POST (invalid token provided)" when: "Existing expected CsrfToken and a POST (invalid token provided)"
response = new MockHttpServletResponse() response = new MockHttpServletResponse()
request = new MockHttpServletRequest(session: request.session, method:'POST') request = new MockHttpServletRequest(session: request.session, method:'POST')
springSecurityFilterChain.doFilter(request,response,chain) springSecurityFilterChain.doFilter(request,response,chain)
then: "Access Denied occurs" then: "Access Denied occurs"
response.status == HttpServletResponse.SC_FORBIDDEN response.status == HttpServletResponse.SC_FORBIDDEN
} }
@EnableWebSecurity @EnableWebSecurity
@ -168,26 +164,26 @@ class CsrfConfigurerTests extends BaseSpringSpec {
@Override @Override
protected void configure(HttpSecurity http) throws Exception { protected void configure(HttpSecurity http) throws Exception {
http http
.csrf().and() .csrf().and()
.sessionManagement() .sessionManagement()
.invalidSessionUrl("/error/sessionError") .invalidSessionUrl("/error/sessionError")
} }
} }
def "csrf requireCsrfProtectionMatcher"() { def "csrf requireCsrfProtectionMatcher"() {
setup: setup:
RequireCsrfProtectionMatcherConfig.matcher = Mock(RequestMatcher) RequireCsrfProtectionMatcherConfig.matcher = Mock(RequestMatcher)
RequireCsrfProtectionMatcherConfig.matcher.matches(_) >>> [false,true] RequireCsrfProtectionMatcherConfig.matcher.matches(_) >>> [false, true]
loadConfig(RequireCsrfProtectionMatcherConfig) loadConfig(RequireCsrfProtectionMatcherConfig)
clearCsrfToken() clearCsrfToken()
when: when:
springSecurityFilterChain.doFilter(request,response,chain) springSecurityFilterChain.doFilter(request,response,chain)
then: then:
response.status == HttpServletResponse.SC_OK response.status == HttpServletResponse.SC_OK
when: when:
springSecurityFilterChain.doFilter(request,response,chain) springSecurityFilterChain.doFilter(request,response,chain)
then: then:
response.status == HttpServletResponse.SC_FORBIDDEN response.status == HttpServletResponse.SC_FORBIDDEN
} }
@EnableWebSecurity @EnableWebSecurity
@ -197,53 +193,53 @@ class CsrfConfigurerTests extends BaseSpringSpec {
@Override @Override
protected void configure(HttpSecurity http) throws Exception { protected void configure(HttpSecurity http) throws Exception {
http http
.csrf() .csrf()
.requireCsrfProtectionMatcher(matcher) .requireCsrfProtectionMatcher(matcher)
} }
} }
def "csrf csrfTokenRepository"() { def "csrf csrfTokenRepository"() {
setup: setup:
CsrfTokenRepositoryConfig.repo = Mock(CsrfTokenRepository) CsrfTokenRepositoryConfig.repo = Mock(CsrfTokenRepository)
loadConfig(CsrfTokenRepositoryConfig) loadConfig(CsrfTokenRepositoryConfig)
clearCsrfToken() clearCsrfToken()
when: when:
springSecurityFilterChain.doFilter(request,response,chain) springSecurityFilterChain.doFilter(request,response,chain)
then: then:
1 * CsrfTokenRepositoryConfig.repo.loadToken(_) >> csrfToken 1 * CsrfTokenRepositoryConfig.repo.loadToken(_) >> csrfToken
response.status == HttpServletResponse.SC_OK response.status == HttpServletResponse.SC_OK
} }
def "csrf clears on logout"() { def "csrf clears on logout"() {
setup: setup:
CsrfTokenRepositoryConfig.repo = Mock(CsrfTokenRepository) CsrfTokenRepositoryConfig.repo = Mock(CsrfTokenRepository)
1 * CsrfTokenRepositoryConfig.repo.loadToken(_) >> csrfToken 1 * CsrfTokenRepositoryConfig.repo.loadToken(_) >> csrfToken
loadConfig(CsrfTokenRepositoryConfig) loadConfig(CsrfTokenRepositoryConfig)
login() login()
request.method = "POST" request.method = "POST"
request.servletPath = "/logout" request.servletPath = "/logout"
when: when:
springSecurityFilterChain.doFilter(request,response,chain) springSecurityFilterChain.doFilter(request,response,chain)
then: then:
1 * CsrfTokenRepositoryConfig.repo.saveToken(null, _, _) 1 * CsrfTokenRepositoryConfig.repo.saveToken(null, _, _)
} }
def "csrf clears on login"() { def "csrf clears on login"() {
setup: setup:
CsrfTokenRepositoryConfig.repo = Mock(CsrfTokenRepository) CsrfTokenRepositoryConfig.repo = Mock(CsrfTokenRepository)
(1.._) * CsrfTokenRepositoryConfig.repo.loadToken(_) >> csrfToken (1.._) * CsrfTokenRepositoryConfig.repo.loadToken(_) >> csrfToken
(1.._) * CsrfTokenRepositoryConfig.repo.generateToken(_) >> csrfToken (1.._) * CsrfTokenRepositoryConfig.repo.generateToken(_) >> csrfToken
loadConfig(CsrfTokenRepositoryConfig) loadConfig(CsrfTokenRepositoryConfig)
request.method = "POST" request.method = "POST"
request.getSession() request.getSession()
request.servletPath = "/login" request.servletPath = "/login"
request.setParameter("username", "user") request.setParameter("username", "user")
request.setParameter("password", "password") request.setParameter("password", "password")
when: when:
springSecurityFilterChain.doFilter(request,response,chain) springSecurityFilterChain.doFilter(request,response,chain)
then: then:
response.redirectedUrl == "/" response.redirectedUrl == "/"
(1.._) * CsrfTokenRepositoryConfig.repo.saveToken(null, _, _) (1.._) * CsrfTokenRepositoryConfig.repo.saveToken(null, _, _)
} }
@EnableWebSecurity @EnableWebSecurity
@ -253,30 +249,30 @@ class CsrfConfigurerTests extends BaseSpringSpec {
@Override @Override
protected void configure(HttpSecurity http) throws Exception { protected void configure(HttpSecurity http) throws Exception {
http http
.formLogin() .formLogin()
.and() .and()
.csrf() .csrf()
.csrfTokenRepository(repo) .csrfTokenRepository(repo)
} }
@Override @Override
protected void configure(AuthenticationManagerBuilder auth) throws Exception { protected void configure(AuthenticationManagerBuilder auth) throws Exception {
auth auth
.inMemoryAuthentication() .inMemoryAuthentication()
.withUser("user").password("password").roles("USER") .withUser("user").password("password").roles("USER")
} }
} }
def "csrf access denied handler"() { def "csrf access denied handler"() {
setup: setup:
AccessDeniedHandlerConfig.deniedHandler = Mock(AccessDeniedHandler) AccessDeniedHandlerConfig.deniedHandler = Mock(AccessDeniedHandler)
1 * AccessDeniedHandlerConfig.deniedHandler.handle(_, _, _) 1 * AccessDeniedHandlerConfig.deniedHandler.handle(_, _, _)
loadConfig(AccessDeniedHandlerConfig) loadConfig(AccessDeniedHandlerConfig)
clearCsrfToken() clearCsrfToken()
request.method = "POST" request.method = "POST"
when: when:
springSecurityFilterChain.doFilter(request,response,chain) springSecurityFilterChain.doFilter(request,response,chain)
then: then:
response.status == HttpServletResponse.SC_OK response.status == HttpServletResponse.SC_OK
} }
@EnableWebSecurity @EnableWebSecurity
@ -286,24 +282,24 @@ class CsrfConfigurerTests extends BaseSpringSpec {
@Override @Override
protected void configure(HttpSecurity http) throws Exception { protected void configure(HttpSecurity http) throws Exception {
http http
.exceptionHandling() .exceptionHandling()
.accessDeniedHandler(deniedHandler) .accessDeniedHandler(deniedHandler)
} }
} }
def "formLogin requires CSRF token"() { def "formLogin requires CSRF token"() {
setup: setup:
loadConfig(FormLoginConfig) loadConfig(FormLoginConfig)
clearCsrfToken() clearCsrfToken()
request.setParameter("username", "user") request.setParameter("username", "user")
request.setParameter("password", "password") request.setParameter("password", "password")
request.servletPath = "/login" request.servletPath = "/login"
request.method = "POST" request.method = "POST"
when: when:
springSecurityFilterChain.doFilter(request,response,chain) springSecurityFilterChain.doFilter(request,response,chain)
then: then:
response.status == HttpServletResponse.SC_FORBIDDEN response.status == HttpServletResponse.SC_FORBIDDEN
currentAuthentication == null currentAuthentication == null
} }
@EnableWebSecurity @EnableWebSecurity
@ -313,34 +309,34 @@ class CsrfConfigurerTests extends BaseSpringSpec {
@Override @Override
protected void configure(HttpSecurity http) throws Exception { protected void configure(HttpSecurity http) throws Exception {
http http
.formLogin() .formLogin()
} }
} }
def "logout requires CSRF token"() { def "logout requires CSRF token"() {
setup: setup:
loadConfig(LogoutConfig) loadConfig(LogoutConfig)
clearCsrfToken() clearCsrfToken()
login() login()
request.servletPath = "/logout" request.servletPath = "/logout"
request.method = "POST" request.method = "POST"
when: when:
springSecurityFilterChain.doFilter(request,response,chain) springSecurityFilterChain.doFilter(request,response,chain)
then: "logout is not allowed and user is still authenticated" then: "logout is not allowed and user is still authenticated"
response.status == HttpServletResponse.SC_FORBIDDEN response.status == HttpServletResponse.SC_FORBIDDEN
currentAuthentication != null currentAuthentication != null
} }
def "SEC-2543: CSRF means logout requires POST"() { def "SEC-2543: CSRF means logout requires POST"() {
setup: setup:
loadConfig(LogoutConfig) loadConfig(LogoutConfig)
login() login()
request.servletPath = "/logout" request.servletPath = "/logout"
request.method = "GET" request.method = "GET"
when: when:
springSecurityFilterChain.doFilter(request,response,chain) springSecurityFilterChain.doFilter(request,response,chain)
then: "logout with GET is not performed" then: "logout with GET is not performed"
currentAuthentication != null currentAuthentication != null
} }
@EnableWebSecurity @EnableWebSecurity
@ -350,20 +346,20 @@ class CsrfConfigurerTests extends BaseSpringSpec {
@Override @Override
protected void configure(HttpSecurity http) throws Exception { protected void configure(HttpSecurity http) throws Exception {
http http
.formLogin() .formLogin()
} }
} }
def "CSRF can explicitly enable GET for logout"() { def "CSRF can explicitly enable GET for logout"() {
setup: setup:
loadConfig(LogoutAllowsGetConfig) loadConfig(LogoutAllowsGetConfig)
login() login()
request.servletPath = "/logout" request.servletPath = "/logout"
request.method = "GET" request.method = "GET"
when: when:
springSecurityFilterChain.doFilter(request,response,chain) springSecurityFilterChain.doFilter(request,response,chain)
then: "logout with GET is not performed" then: "logout with GET is not performed"
currentAuthentication == null currentAuthentication == null
} }
@EnableWebSecurity @EnableWebSecurity
@ -373,64 +369,64 @@ class CsrfConfigurerTests extends BaseSpringSpec {
@Override @Override
protected void configure(HttpSecurity http) throws Exception { protected void configure(HttpSecurity http) throws Exception {
http http
.formLogin().and() .formLogin().and()
.logout() .logout()
.logoutRequestMatcher(new AntPathRequestMatcher("/logout")) .logoutRequestMatcher(new AntPathRequestMatcher("/logout"))
} }
} }
def "csrf disables POST requests from RequestCache"() { def "csrf disables POST requests from RequestCache"() {
setup: setup:
CsrfDisablesPostRequestFromRequestCacheConfig.repo = Mock(CsrfTokenRepository) CsrfDisablesPostRequestFromRequestCacheConfig.repo = Mock(CsrfTokenRepository)
(1.._) * CsrfDisablesPostRequestFromRequestCacheConfig.repo.generateToken(_) >> csrfToken (1.._) * CsrfDisablesPostRequestFromRequestCacheConfig.repo.generateToken(_) >> csrfToken
loadConfig(CsrfDisablesPostRequestFromRequestCacheConfig) loadConfig(CsrfDisablesPostRequestFromRequestCacheConfig)
request.servletPath = "/some-url" request.servletPath = "/some-url"
request.requestURI = "/some-url" request.requestURI = "/some-url"
request.method = "POST" request.method = "POST"
when: "CSRF passes and our session times out" when: "CSRF passes and our session times out"
springSecurityFilterChain.doFilter(request,response,chain) springSecurityFilterChain.doFilter(request,response,chain)
then: "sent to the login page" then: "sent to the login page"
(1.._) * CsrfDisablesPostRequestFromRequestCacheConfig.repo.loadToken(_) >> csrfToken (1.._) * CsrfDisablesPostRequestFromRequestCacheConfig.repo.loadToken(_) >> csrfToken
response.status == HttpServletResponse.SC_MOVED_TEMPORARILY response.status == HttpServletResponse.SC_MOVED_TEMPORARILY
response.redirectedUrl == "http://localhost/login" response.redirectedUrl == "http://localhost/login"
when: "authenticate successfully" when: "authenticate successfully"
super.setupWeb(request.session) super.setupWeb(request.session)
request.servletPath = "/login" request.servletPath = "/login"
request.setParameter("username","user") request.setParameter("username","user")
request.setParameter("password","password") request.setParameter("password","password")
request.method = "POST" request.method = "POST"
springSecurityFilterChain.doFilter(request,response,chain) springSecurityFilterChain.doFilter(request,response,chain)
then: "sent to default success because we don't want csrf attempts made prior to authentication to pass" then: "sent to default success because we don't want csrf attempts made prior to authentication to pass"
(1.._) * CsrfDisablesPostRequestFromRequestCacheConfig.repo.loadToken(_) >> csrfToken (1.._) * CsrfDisablesPostRequestFromRequestCacheConfig.repo.loadToken(_) >> csrfToken
response.status == HttpServletResponse.SC_MOVED_TEMPORARILY response.status == HttpServletResponse.SC_MOVED_TEMPORARILY
response.redirectedUrl == "/" response.redirectedUrl == "/"
} }
def "csrf enables GET requests with RequestCache"() { def "csrf enables GET requests with RequestCache"() {
setup: setup:
CsrfDisablesPostRequestFromRequestCacheConfig.repo = Mock(CsrfTokenRepository) CsrfDisablesPostRequestFromRequestCacheConfig.repo = Mock(CsrfTokenRepository)
(1.._) * CsrfDisablesPostRequestFromRequestCacheConfig.repo.generateToken(_) >> csrfToken (1.._) * CsrfDisablesPostRequestFromRequestCacheConfig.repo.generateToken(_) >> csrfToken
loadConfig(CsrfDisablesPostRequestFromRequestCacheConfig) loadConfig(CsrfDisablesPostRequestFromRequestCacheConfig)
request.servletPath = "/some-url" request.servletPath = "/some-url"
request.requestURI = "/some-url" request.requestURI = "/some-url"
request.method = "GET" request.method = "GET"
when: "CSRF passes and our session times out" when: "CSRF passes and our session times out"
springSecurityFilterChain.doFilter(request,response,chain) springSecurityFilterChain.doFilter(request,response,chain)
then: "sent to the login page" then: "sent to the login page"
(1.._) * CsrfDisablesPostRequestFromRequestCacheConfig.repo.loadToken(_) >> csrfToken (1.._) * CsrfDisablesPostRequestFromRequestCacheConfig.repo.loadToken(_) >> csrfToken
response.status == HttpServletResponse.SC_MOVED_TEMPORARILY response.status == HttpServletResponse.SC_MOVED_TEMPORARILY
response.redirectedUrl == "http://localhost/login" response.redirectedUrl == "http://localhost/login"
when: "authenticate successfully" when: "authenticate successfully"
super.setupWeb(request.session) super.setupWeb(request.session)
request.servletPath = "/login" request.servletPath = "/login"
request.setParameter("username","user") request.setParameter("username","user")
request.setParameter("password","password") request.setParameter("password","password")
request.method = "POST" request.method = "POST"
springSecurityFilterChain.doFilter(request,response,chain) springSecurityFilterChain.doFilter(request,response,chain)
then: "sent to original URL since it was a GET" then: "sent to original URL since it was a GET"
(1.._) * CsrfDisablesPostRequestFromRequestCacheConfig.repo.loadToken(_) >> csrfToken (1.._) * CsrfDisablesPostRequestFromRequestCacheConfig.repo.loadToken(_) >> csrfToken
response.status == HttpServletResponse.SC_MOVED_TEMPORARILY response.status == HttpServletResponse.SC_MOVED_TEMPORARILY
response.redirectedUrl == "http://localhost/some-url" response.redirectedUrl == "http://localhost/some-url"
} }
@EnableWebSecurity @EnableWebSecurity
@ -440,18 +436,18 @@ class CsrfConfigurerTests extends BaseSpringSpec {
@Override @Override
protected void configure(HttpSecurity http) throws Exception { protected void configure(HttpSecurity http) throws Exception {
http http
.authorizeRequests() .authorizeRequests()
.anyRequest().authenticated() .anyRequest().authenticated()
.and() .and()
.formLogin() .formLogin()
.and() .and()
.csrf() .csrf()
.csrfTokenRepository(repo) .csrfTokenRepository(repo)
} }
@Override @Override
protected void configure(AuthenticationManagerBuilder auth) throws Exception { protected void configure(AuthenticationManagerBuilder auth) throws Exception {
auth auth
.inMemoryAuthentication() .inMemoryAuthentication()
.withUser("user").password("password").roles("USER") .withUser("user").password("password").roles("USER")
} }
} }
@ -463,6 +459,39 @@ class CsrfConfigurerTests extends BaseSpringSpec {
thrown(IllegalArgumentException) thrown(IllegalArgumentException)
} }
def 'default does not create session'() {
setup:
request = new MockHttpServletRequest(method:'GET')
loadConfig(DefaultDoesNotCreateSession)
when:
springSecurityFilterChain.doFilter(request,response,chain)
then:
request.getSession(false) == null
}
@EnableWebSecurity(debug=true)
static class DefaultDoesNotCreateSession extends WebSecurityConfigurerAdapter {
@Override
protected void configure(HttpSecurity http) throws Exception {
// @formatter:off
http
.authorizeRequests()
.anyRequest().permitAll()
.and()
.formLogin().and()
.httpBasic();
// @formatter:on
}
@Override
protected void configure(AuthenticationManagerBuilder auth) throws Exception {
auth
.inMemoryAuthentication()
.withUser("user").password("password").roles("USER")
}
}
def clearCsrfToken() { def clearCsrfToken() {
request.removeAllParameters() request.removeAllParameters()
} }

View File

@ -12,12 +12,11 @@
*/ */
package org.springframework.security.config.http package org.springframework.security.config.http
import static org.mockito.Matchers.*
import static org.mockito.Mockito.*
import javax.servlet.http.HttpServletRequest import javax.servlet.http.HttpServletRequest
import javax.servlet.http.HttpServletResponse import javax.servlet.http.HttpServletResponse
import spock.lang.Unroll
import org.springframework.mock.web.MockFilterChain import org.springframework.mock.web.MockFilterChain
import org.springframework.mock.web.MockHttpServletRequest import org.springframework.mock.web.MockHttpServletRequest
import org.springframework.mock.web.MockHttpServletResponse import org.springframework.mock.web.MockHttpServletResponse
@ -36,7 +35,8 @@ import org.springframework.security.web.csrf.DefaultCsrfToken
import org.springframework.security.web.util.matcher.RequestMatcher import org.springframework.security.web.util.matcher.RequestMatcher
import org.springframework.web.servlet.support.RequestDataValueProcessor import org.springframework.web.servlet.support.RequestDataValueProcessor
import spock.lang.Unroll import static org.mockito.Matchers.*
import static org.mockito.Mockito.*
/** /**
* *
@ -73,256 +73,253 @@ class CsrfConfigTests extends AbstractHttpConfigTests {
def 'csrf disabled'() { def 'csrf disabled'() {
when: when:
httpAutoConfig { httpAutoConfig { csrf(disabled:true) }
csrf(disabled:true) createAppContext()
}
createAppContext()
then: then:
!getFilter(CsrfFilter) !getFilter(CsrfFilter)
} }
@Unroll @Unroll
def 'csrf defaults'() { def 'csrf defaults'() {
setup: setup:
httpAutoConfig { httpAutoConfig { 'csrf'() }
'csrf'() createAppContext()
}
createAppContext()
when: when:
request.method = httpMethod request.method = httpMethod
springSecurityFilterChain.doFilter(request,response,chain) springSecurityFilterChain.doFilter(request,response,chain)
then: then:
response.status == httpStatus response.status == httpStatus
where: where:
httpMethod | httpStatus httpMethod | httpStatus
'POST' | HttpServletResponse.SC_FORBIDDEN 'POST' | HttpServletResponse.SC_FORBIDDEN
'PUT' | HttpServletResponse.SC_FORBIDDEN 'PUT' | HttpServletResponse.SC_FORBIDDEN
'PATCH' | HttpServletResponse.SC_FORBIDDEN 'PATCH' | HttpServletResponse.SC_FORBIDDEN
'DELETE' | HttpServletResponse.SC_FORBIDDEN 'DELETE' | HttpServletResponse.SC_FORBIDDEN
'INVALID' | HttpServletResponse.SC_FORBIDDEN 'INVALID' | HttpServletResponse.SC_FORBIDDEN
'GET' | HttpServletResponse.SC_OK 'GET' | HttpServletResponse.SC_OK
'HEAD' | HttpServletResponse.SC_OK 'HEAD' | HttpServletResponse.SC_OK
'TRACE' | HttpServletResponse.SC_OK 'TRACE' | HttpServletResponse.SC_OK
'OPTIONS' | HttpServletResponse.SC_OK 'OPTIONS' | HttpServletResponse.SC_OK
} }
def 'csrf default creates CsrfRequestDataValueProcessor'() { def 'csrf default creates CsrfRequestDataValueProcessor'() {
when: when:
httpAutoConfig { httpAutoConfig { 'csrf'() }
'csrf'() createAppContext()
}
createAppContext()
then: then:
appContext.getBean("requestDataValueProcessor",RequestDataValueProcessor) appContext.getBean("requestDataValueProcessor",RequestDataValueProcessor)
} }
def 'csrf custom AccessDeniedHandler'() { def 'csrf custom AccessDeniedHandler'() {
setup: setup:
httpAutoConfig { httpAutoConfig {
'access-denied-handler'(ref:'adh') 'access-denied-handler'(ref:'adh')
'csrf'() 'csrf'()
} }
mockBean(AccessDeniedHandler,'adh') mockBean(AccessDeniedHandler,'adh')
createAppContext() createAppContext()
AccessDeniedHandler adh = appContext.getBean(AccessDeniedHandler) AccessDeniedHandler adh = appContext.getBean(AccessDeniedHandler)
request.method = "POST" request.method = "POST"
when: when:
springSecurityFilterChain.doFilter(request,response,chain) springSecurityFilterChain.doFilter(request,response,chain)
then: then:
verify(adh).handle(any(HttpServletRequest),any(HttpServletResponse),any(AccessDeniedException)) verify(adh).handle(any(HttpServletRequest),any(HttpServletResponse),any(AccessDeniedException))
response.status == HttpServletResponse.SC_OK // our mock doesn't do anything response.status == HttpServletResponse.SC_OK // our mock doesn't do anything
} }
def "csrf disables posts for RequestCache"() { def "csrf disables posts for RequestCache"() {
setup: setup:
httpAutoConfig { httpAutoConfig {
'csrf'('token-repository-ref':'repo') 'csrf'('token-repository-ref':'repo')
'intercept-url'(pattern:"/**",access:'ROLE_USER') 'intercept-url'(pattern:"/**",access:'ROLE_USER')
} }
mockBean(CsrfTokenRepository,'repo') mockBean(CsrfTokenRepository,'repo')
createAppContext() createAppContext()
CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository) CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository)
CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc") CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc")
when(repo.loadToken(any(HttpServletRequest))).thenReturn(token) when(repo.loadToken(any(HttpServletRequest))).thenReturn(token)
when(repo.generateToken(any(HttpServletRequest))).thenReturn(token) when(repo.generateToken(any(HttpServletRequest))).thenReturn(token)
request.setParameter(token.parameterName,token.token) request.setParameter(token.parameterName,token.token)
request.servletPath = "/some-url" request.servletPath = "/some-url"
request.requestURI = "/some-url" request.requestURI = "/some-url"
request.method = "POST" request.method = "POST"
when: "CSRF passes and our session times out" when: "CSRF passes and our session times out"
springSecurityFilterChain.doFilter(request,response,chain) springSecurityFilterChain.doFilter(request,response,chain)
then: "sent to the login page" then: "sent to the login page"
response.status == HttpServletResponse.SC_MOVED_TEMPORARILY response.status == HttpServletResponse.SC_MOVED_TEMPORARILY
response.redirectedUrl == "http://localhost/login" response.redirectedUrl == "http://localhost/login"
when: "authenticate successfully" when: "authenticate successfully"
response = new MockHttpServletResponse() response = new MockHttpServletResponse()
request = new MockHttpServletRequest(session: request.session) request = new MockHttpServletRequest(session: request.session)
request.servletPath = "/login" request.servletPath = "/login"
request.setParameter(token.parameterName,token.token) request.setParameter(token.parameterName,token.token)
request.setParameter("username","user") request.setParameter("username","user")
request.setParameter("password","password") request.setParameter("password","password")
request.method = "POST" request.method = "POST"
springSecurityFilterChain.doFilter(request,response,chain) springSecurityFilterChain.doFilter(request,response,chain)
then: "sent to default success because we don't want csrf attempts made prior to authentication to pass" then: "sent to default success because we don't want csrf attempts made prior to authentication to pass"
response.status == HttpServletResponse.SC_MOVED_TEMPORARILY response.status == HttpServletResponse.SC_MOVED_TEMPORARILY
response.redirectedUrl == "/" response.redirectedUrl == "/"
} }
def "csrf enables gets for RequestCache"() { def "csrf enables gets for RequestCache"() {
setup: setup:
httpAutoConfig { httpAutoConfig {
'csrf'('token-repository-ref':'repo') 'csrf'('token-repository-ref':'repo')
'intercept-url'(pattern:"/**",access:'ROLE_USER') 'intercept-url'(pattern:"/**",access:'ROLE_USER')
} }
mockBean(CsrfTokenRepository,'repo') mockBean(CsrfTokenRepository,'repo')
createAppContext() createAppContext()
CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository) CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository)
CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc") CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc")
when(repo.loadToken(any(HttpServletRequest))).thenReturn(token) when(repo.loadToken(any(HttpServletRequest))).thenReturn(token)
when(repo.generateToken(any(HttpServletRequest))).thenReturn(token) when(repo.generateToken(any(HttpServletRequest))).thenReturn(token)
request.setParameter(token.parameterName,token.token) request.setParameter(token.parameterName,token.token)
request.servletPath = "/some-url" request.servletPath = "/some-url"
request.requestURI = "/some-url" request.requestURI = "/some-url"
request.method = "GET" request.method = "GET"
when: "CSRF passes and our session times out" when: "CSRF passes and our session times out"
springSecurityFilterChain.doFilter(request,response,chain) springSecurityFilterChain.doFilter(request,response,chain)
then: "sent to the login page" then: "sent to the login page"
response.status == HttpServletResponse.SC_MOVED_TEMPORARILY response.status == HttpServletResponse.SC_MOVED_TEMPORARILY
response.redirectedUrl == "http://localhost/login" response.redirectedUrl == "http://localhost/login"
when: "authenticate successfully" when: "authenticate successfully"
response = new MockHttpServletResponse() response = new MockHttpServletResponse()
request = new MockHttpServletRequest(session: request.session) request = new MockHttpServletRequest(session: request.session)
request.servletPath = "/login" request.servletPath = "/login"
request.setParameter(token.parameterName,token.token) request.setParameter(token.parameterName,token.token)
request.setParameter("username","user") request.setParameter("username","user")
request.setParameter("password","password") request.setParameter("password","password")
request.method = "POST" request.method = "POST"
springSecurityFilterChain.doFilter(request,response,chain) springSecurityFilterChain.doFilter(request,response,chain)
then: "sent to original URL since it was a GET" then: "sent to original URL since it was a GET"
response.status == HttpServletResponse.SC_MOVED_TEMPORARILY response.status == HttpServletResponse.SC_MOVED_TEMPORARILY
response.redirectedUrl == "http://localhost/some-url" response.redirectedUrl == "http://localhost/some-url"
} }
def "SEC-2422: csrf expire CSRF token and session-management invalid-session-url"() { def "SEC-2422: csrf expire CSRF token and session-management invalid-session-url"() {
setup: setup:
httpAutoConfig { httpAutoConfig {
'csrf'() 'csrf'()
'session-management'('invalid-session-url': '/error/sessionError') 'session-management'('invalid-session-url': '/error/sessionError')
} }
createAppContext() createAppContext()
request.setParameter("_csrf","abc") request.setParameter("_csrf","abc")
request.method = "POST" request.method = "POST"
when: "No existing expected CsrfToken (session times out) and a POST" when: "No existing expected CsrfToken (session times out) and a POST"
springSecurityFilterChain.doFilter(request,response,chain) springSecurityFilterChain.doFilter(request,response,chain)
then: "sent to the session timeout page page" then: "sent to the session timeout page page"
response.status == HttpServletResponse.SC_MOVED_TEMPORARILY response.status == HttpServletResponse.SC_MOVED_TEMPORARILY
response.redirectedUrl == "/error/sessionError" response.redirectedUrl == "/error/sessionError"
when: "Existing expected CsrfToken and a POST (invalid token provided)" when: "Existing expected CsrfToken and a POST (invalid token provided)"
response = new MockHttpServletResponse() response = new MockHttpServletResponse()
request = new MockHttpServletRequest(session: request.session, method:'POST') request = new MockHttpServletRequest(session: request.session, method:'POST')
springSecurityFilterChain.doFilter(request,response,chain) springSecurityFilterChain.doFilter(request,response,chain)
then: "Access Denied occurs" then: "Access Denied occurs"
response.status == HttpServletResponse.SC_FORBIDDEN response.status == HttpServletResponse.SC_FORBIDDEN
} }
def "csrf requireCsrfProtectionMatcher"() { def "csrf requireCsrfProtectionMatcher"() {
setup: setup:
httpAutoConfig { httpAutoConfig { 'csrf'('request-matcher-ref':'matcher') }
'csrf'('request-matcher-ref':'matcher') mockBean(RequestMatcher,'matcher')
} createAppContext()
mockBean(RequestMatcher,'matcher') request.method = 'POST'
createAppContext() RequestMatcher matcher = appContext.getBean("matcher",RequestMatcher)
request.method = 'POST'
RequestMatcher matcher = appContext.getBean("matcher",RequestMatcher)
when: when:
when(matcher.matches(any(HttpServletRequest))).thenReturn(false) when(matcher.matches(any(HttpServletRequest))).thenReturn(false)
springSecurityFilterChain.doFilter(request,response,chain) springSecurityFilterChain.doFilter(request,response,chain)
then: then:
response.status == HttpServletResponse.SC_OK response.status == HttpServletResponse.SC_OK
when: when:
when(matcher.matches(any(HttpServletRequest))).thenReturn(true) when(matcher.matches(any(HttpServletRequest))).thenReturn(true)
springSecurityFilterChain.doFilter(request,response,chain) springSecurityFilterChain.doFilter(request,response,chain)
then: then:
response.status == HttpServletResponse.SC_FORBIDDEN response.status == HttpServletResponse.SC_FORBIDDEN
}
def "csrf csrfTokenRepository default delays save"() {
setup:
httpAutoConfig {
}
createAppContext()
request.method = "GET"
when:
springSecurityFilterChain.doFilter(request,response,chain)
then:
response.status == HttpServletResponse.SC_OK
request.getSession(false) == null
} }
def "csrf csrfTokenRepository"() { def "csrf csrfTokenRepository"() {
setup: setup:
httpAutoConfig { httpAutoConfig { 'csrf'('token-repository-ref':'repo') }
'csrf'('token-repository-ref':'repo') mockBean(CsrfTokenRepository,'repo')
} createAppContext()
mockBean(CsrfTokenRepository,'repo') CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository)
createAppContext() CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc")
CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository) when(repo.loadToken(any(HttpServletRequest))).thenReturn(token)
CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc") request.setParameter(token.parameterName,token.token)
when(repo.loadToken(any(HttpServletRequest))).thenReturn(token) request.method = "POST"
request.setParameter(token.parameterName,token.token)
request.method = "POST"
when: when:
springSecurityFilterChain.doFilter(request,response,chain) springSecurityFilterChain.doFilter(request,response,chain)
then: then:
response.status == HttpServletResponse.SC_OK response.status == HttpServletResponse.SC_OK
when: when:
request.setParameter(token.parameterName,token.token+"INVALID") request.setParameter(token.parameterName,token.token+"INVALID")
springSecurityFilterChain.doFilter(request,response,chain) springSecurityFilterChain.doFilter(request,response,chain)
then: then:
response.status == HttpServletResponse.SC_FORBIDDEN response.status == HttpServletResponse.SC_FORBIDDEN
} }
def "csrf clears on login"() { def "csrf clears on login"() {
setup: setup:
httpAutoConfig { httpAutoConfig { 'csrf'('token-repository-ref':'repo') }
'csrf'('token-repository-ref':'repo') mockBean(CsrfTokenRepository,'repo')
} createAppContext()
mockBean(CsrfTokenRepository,'repo') CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository)
createAppContext() CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc")
CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository) when(repo.loadToken(any(HttpServletRequest))).thenReturn(token)
CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc") when(repo.generateToken(any(HttpServletRequest))).thenReturn(token)
when(repo.loadToken(any(HttpServletRequest))).thenReturn(token) request.setParameter(token.parameterName,token.token)
when(repo.generateToken(any(HttpServletRequest))).thenReturn(token) request.method = "POST"
request.setParameter(token.parameterName,token.token) request.setParameter("username","user")
request.method = "POST" request.setParameter("password","password")
request.setParameter("username","user") request.servletPath = "/login"
request.setParameter("password","password")
request.servletPath = "/login"
when: when:
springSecurityFilterChain.doFilter(request,response,chain) springSecurityFilterChain.doFilter(request,response,chain)
then: then:
verify(repo, atLeastOnce()).saveToken(eq(null),any(HttpServletRequest), any(HttpServletResponse)) verify(repo, atLeastOnce()).saveToken(eq(null),any(HttpServletRequest), any(HttpServletResponse))
} }
def "csrf clears on logout"() { def "csrf clears on logout"() {
setup: setup:
httpAutoConfig { httpAutoConfig { 'csrf'('token-repository-ref':'repo') }
'csrf'('token-repository-ref':'repo') mockBean(CsrfTokenRepository,'repo')
} createAppContext()
mockBean(CsrfTokenRepository,'repo') CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository)
createAppContext() CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc")
CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository) when(repo.loadToken(any(HttpServletRequest))).thenReturn(token)
CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc") request.setParameter(token.parameterName,token.token)
when(repo.loadToken(any(HttpServletRequest))).thenReturn(token) request.method = "POST"
request.setParameter(token.parameterName,token.token) request.servletPath = "/logout"
request.method = "POST"
request.servletPath = "/logout"
when: when:
springSecurityFilterChain.doFilter(request,response,chain) springSecurityFilterChain.doFilter(request,response,chain)
then: then:
verify(repo).saveToken(eq(null),any(HttpServletRequest), any(HttpServletResponse)) verify(repo).saveToken(eq(null),any(HttpServletRequest), any(HttpServletResponse))
} }
def "SEC-2495: csrf disables logout on GET"() { def "SEC-2495: csrf disables logout on GET"() {
setup: setup:
httpAutoConfig { httpAutoConfig { 'csrf'() }
'csrf'() createAppContext()
} login()
createAppContext() request.method = "GET"
login() request.requestURI = "/logout"
request.method = "GET"
request.requestURI = "/logout"
when: when:
springSecurityFilterChain.doFilter(request,response,chain) springSecurityFilterChain.doFilter(request,response,chain)
then: then:
getAuthentication(request) != null getAuthentication(request) != null
} }

View File

@ -378,6 +378,7 @@ You can find the highlights below:
* <<el-access-web-path-variables,Path Variables in Web Security Expressions>> * <<el-access-web-path-variables,Path Variables in Web Security Expressions>>
* <<headers-csp,Content Security Policy (CSP)>> * <<headers-csp,Content Security Policy (CSP)>>
* <<headers-hpkp,HTTP Public Key Pinning (HPKP)>> * <<headers-hpkp,HTTP Public Key Pinning (HPKP)>>
* <<csrf-cookie,CookieCsrfTokenRepository>> provides simple AngularJS & CSRF integration
* Added `ForwardAuthenticationFailureHandler` & `ForwardAuthenticationSuccessHandler` * Added `ForwardAuthenticationFailureHandler` & `ForwardAuthenticationSuccessHandler`
* SCrypt support with `SCryptPasswordEncoder` * SCrypt support with `SCryptPasswordEncoder`
* Meta Annotation Support * Meta Annotation Support
@ -3252,6 +3253,7 @@ protected void configure(HttpSecurity http) throws Exception {
} }
---- ----
[[csrf-include-csrf-token]] [[csrf-include-csrf-token]]
==== Include the CSRF Token ==== Include the CSRF Token
@ -3324,6 +3326,41 @@ name: $("meta[name='_csrf_header']").attr("content")
The configured client can be shared with any component of the application that needs to make a request to the CSRF protected resource. One significant different between rest.js and jQuery is that only requests made with the configured client will contain the CSRF token, vs jQuery where __all__ requests will include the token. The ability to scope which requests receive the token helps guard against leaking the CSRF token to a third party. Please refer to the https://github.com/cujojs/rest/tree/master/docs[rest.js reference documentation] for more information on rest.js. The configured client can be shared with any component of the application that needs to make a request to the CSRF protected resource. One significant different between rest.js and jQuery is that only requests made with the configured client will contain the CSRF token, vs jQuery where __all__ requests will include the token. The ability to scope which requests receive the token helps guard against leaking the CSRF token to a third party. Please refer to the https://github.com/cujojs/rest/tree/master/docs[rest.js reference documentation] for more information on rest.js.
[[csrf-cookie]]
===== CookieCsrfTokenRepository
There can be cases where users will want to persist the `CsrfToken` in a cookie.
By default the `CookieCsrfTokenRepository` will write to a cookie named `XSRF-TOKEN` and read it from a header named `X-XSRF-TOKEN` or the HTTP parameter `_csrf`.
These defaults come from https://docs.angularjs.org/api/ng/service/$http#cross-site-request-forgery-xsrf-protection[AngularJS]
You can configure `CookieCsrfTokenRepository` in XML using the following:
[source,xml]
----
<http>
<!-- ... -->
<csrf token-repository-ref="tokenRepository"/>
</http>
<b:bean id="tokenRepository" class="org.springframework.security.web.csrf.CookieCsrfTokenRepository"/>
----
You can configure `CookieCsrfTokenRepository` in Java Configuration using:
[source,java]
----
@EnableWebSecurity
public class WebSecurityConfig extends
WebSecurityConfigurerAdapter {
@Override
protected void configure(HttpSecurity http) throws Exception {
http
.csrf()
.csrfTokenRepository(new CookieCsrfTokenRepository());
}
}
----
[[csrf-caveats]] [[csrf-caveats]]
=== CSRF Caveats === CSRF Caveats
@ -3336,13 +3373,16 @@ One issue is that the expected CSRF token is stored in the HttpSession, so as so
[NOTE] [NOTE]
==== ====
One might ask why the expected `CsrfToken` isn't stored in a cookie. This is because there are known exploits in which headers (i.e. specify the cookies) can be set by another domain. This is the same reason Ruby on Rails http://weblog.rubyonrails.org/2011/2/8/csrf-protection-bypass-in-ruby-on-rails/[no longer skips CSRF checks when the header X-Requested-With is present]. See http://lists.webappsec.org/pipermail/websecurity_lists.webappsec.org/2011-February/007533.html[this webappsec.org thread] for details on how to perform the exploit. Another disadvantage is that by removing the state (i.e. the timeout) you lose the ability to forcibly terminate the token if it is compromised. One might ask why the expected `CsrfToken` isn't stored in a cookie by default. This is because there are known exploits in which headers (i.e. specify the cookies) can be set by another domain. This is the same reason Ruby on Rails http://weblog.rubyonrails.org/2011/2/8/csrf-protection-bypass-in-ruby-on-rails/[no longer skips CSRF checks when the header X-Requested-With is present]. See http://lists.webappsec.org/pipermail/websecurity_lists.webappsec.org/2011-February/007533.html[this webappsec.org thread] for details on how to perform the exploit. Another disadvantage is that by removing the state (i.e. the timeout) you lose the ability to forcibly terminate the token if it is compromised.
==== ====
A simple way to mitigate an active user experiencing a timeout is to have some JavaScript that lets the user know their session is about to expire. The user can click a button to continue and refresh the session. A simple way to mitigate an active user experiencing a timeout is to have some JavaScript that lets the user know their session is about to expire. The user can click a button to continue and refresh the session.
Alternatively, specifying a custom `AccessDeniedHandler` allows you to process the `InvalidCsrfTokenException` any way you like. For an example of how to customize the `AccessDeniedHandler` refer to the provided links for both <<nsa-access-denied-handler,xml>> and https://github.com/spring-projects/spring-security/blob/3.2.0.RC1/config/src/test/groovy/org/springframework/security/config/annotation/web/configurers/NamespaceHttpAccessDeniedHandlerTests.groovy#L64[Java configuration]. Alternatively, specifying a custom `AccessDeniedHandler` allows you to process the `InvalidCsrfTokenException` any way you like. For an example of how to customize the `AccessDeniedHandler` refer to the provided links for both <<nsa-access-denied-handler,xml>> and https://github.com/spring-projects/spring-security/blob/3.2.0.RC1/config/src/test/groovy/org/springframework/security/config/annotation/web/configurers/NamespaceHttpAccessDeniedHandlerTests.groovy#L64[Java configuration].
Finally, the application can be configured to use <<csrf-cookie,CookieCsrfTokenRepository>> which will not expire.
As previously mentioned, this is not as secure as using a session, but in many cases can be good enough.
[[csrf-login]] [[csrf-login]]
==== Logging In ==== Logging In

View File

@ -53,6 +53,7 @@ import org.springframework.security.web.context.SecurityContextPersistenceFilter
import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.CsrfTokenRepository; import org.springframework.security.web.csrf.CsrfTokenRepository;
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.request.RequestPostProcessor; import org.springframework.test.web.servlet.request.RequestPostProcessor;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@ -107,8 +108,8 @@ public final class SecurityMockMvcRequestPostProcessors {
* @throws IOException * @throws IOException
* @throws CertificateException * @throws CertificateException
*/ */
public static RequestPostProcessor x509(String resourceName) throws IOException, public static RequestPostProcessor x509(String resourceName)
CertificateException { throws IOException, CertificateException {
ResourceLoader loader = new DefaultResourceLoader(); ResourceLoader loader = new DefaultResourceLoader();
Resource resource = loader.getResource(resourceName); Resource resource = loader.getResource(resourceName);
InputStream inputStream = resource.getInputStream(); InputStream inputStream = resource.getInputStream();
@ -142,24 +143,24 @@ public final class SecurityMockMvcRequestPostProcessors {
* Establish a {@link SecurityContext} that has a * Establish a {@link SecurityContext} that has a
* {@link UsernamePasswordAuthenticationToken} for the * {@link UsernamePasswordAuthenticationToken} for the
* {@link Authentication#getPrincipal()} and a {@link User} for the * {@link Authentication#getPrincipal()} and a {@link User} for the
* {@link UsernamePasswordAuthenticationToken#getPrincipal()}. All details * {@link UsernamePasswordAuthenticationToken#getPrincipal()}. All details are
* are declarative and do not require that the user actually exists. * declarative and do not require that the user actually exists.
* *
* <p> * <p>
* The support works by associating the user to the HttpServletRequest. To * The support works by associating the user to the HttpServletRequest. To associate
* associate the request to the SecurityContextHolder you need to ensure * the request to the SecurityContextHolder you need to ensure that the
* that the SecurityContextPersistenceFilter is associated with the * SecurityContextPersistenceFilter is associated with the MockMvc instance. A few
* MockMvc instance. A few ways to do this are: * ways to do this are:
* </p> * </p>
* *
* <ul> * <ul>
* <li>Invoking apply {@link SecurityMockMvcConfigurers#springSecurity()}</li> * <li>Invoking apply {@link SecurityMockMvcConfigurers#springSecurity()}</li>
* <li>Adding Spring Security's FilterChainProxy to MockMvc</li> * <li>Adding Spring Security's FilterChainProxy to MockMvc</li>
* <li>Manually adding {@link SecurityContextPersistenceFilter} to the MockMvc instance may make sense when using MockMvcBuilders standaloneSetup</li> * <li>Manually adding {@link SecurityContextPersistenceFilter} to the MockMvc
* instance may make sense when using MockMvcBuilders standaloneSetup</li>
* </ul> * </ul>
* *
* @param username * @param username the username to populate
* the username to populate
* @return the {@link UserRequestPostProcessor} for additional customization * @return the {@link UserRequestPostProcessor} for additional customization
*/ */
public static UserRequestPostProcessor user(String username) { public static UserRequestPostProcessor user(String username) {
@ -174,16 +175,17 @@ public final class SecurityMockMvcRequestPostProcessors {
* declarative and do not require that the user actually exists. * declarative and do not require that the user actually exists.
* *
* <p> * <p>
* The support works by associating the user to the HttpServletRequest. To * The support works by associating the user to the HttpServletRequest. To associate
* associate the request to the SecurityContextHolder you need to ensure * the request to the SecurityContextHolder you need to ensure that the
* that the SecurityContextPersistenceFilter is associated with the * SecurityContextPersistenceFilter is associated with the MockMvc instance. A few
* MockMvc instance. A few ways to do this are: * ways to do this are:
* </p> * </p>
* *
* <ul> * <ul>
* <li>Invoking apply {@link SecurityMockMvcConfigurers#springSecurity()}</li> * <li>Invoking apply {@link SecurityMockMvcConfigurers#springSecurity()}</li>
* <li>Adding Spring Security's FilterChainProxy to MockMvc</li> * <li>Adding Spring Security's FilterChainProxy to MockMvc</li>
* <li>Manually adding {@link SecurityContextPersistenceFilter} to the MockMvc instance may make sense when using MockMvcBuilders standaloneSetup</li> * <li>Manually adding {@link SecurityContextPersistenceFilter} to the MockMvc
* instance may make sense when using MockMvcBuilders standaloneSetup</li>
* </ul> * </ul>
* *
* @param user the UserDetails to populate * @param user the UserDetails to populate
@ -199,16 +201,17 @@ public final class SecurityMockMvcRequestPostProcessors {
* details are declarative and do not require that the user actually exists. * details are declarative and do not require that the user actually exists.
* *
* <p> * <p>
* The support works by associating the user to the HttpServletRequest. To * The support works by associating the user to the HttpServletRequest. To associate
* associate the request to the SecurityContextHolder you need to ensure * the request to the SecurityContextHolder you need to ensure that the
* that the SecurityContextPersistenceFilter is associated with the * SecurityContextPersistenceFilter is associated with the MockMvc instance. A few
* MockMvc instance. A few ways to do this are: * ways to do this are:
* </p> * </p>
* *
* <ul> * <ul>
* <li>Invoking apply {@link SecurityMockMvcConfigurers#springSecurity()}</li> * <li>Invoking apply {@link SecurityMockMvcConfigurers#springSecurity()}</li>
* <li>Adding Spring Security's FilterChainProxy to MockMvc</li> * <li>Adding Spring Security's FilterChainProxy to MockMvc</li>
* <li>Manually adding {@link SecurityContextPersistenceFilter} to the MockMvc instance may make sense when using MockMvcBuilders standaloneSetup</li> * <li>Manually adding {@link SecurityContextPersistenceFilter} to the MockMvc
* instance may make sense when using MockMvcBuilders standaloneSetup</li>
* </ul> * </ul>
* *
* @param authentication the Authentication to populate * @param authentication the Authentication to populate
@ -220,9 +223,9 @@ public final class SecurityMockMvcRequestPostProcessors {
/** /**
* Establish a {@link SecurityContext} that uses an * Establish a {@link SecurityContext} that uses an
* {@link AnonymousAuthenticationToken}. This is useful when a user wants to * {@link AnonymousAuthenticationToken}. This is useful when a user wants to run a
* run a majority of tests as a specific user and wishes to override a few * majority of tests as a specific user and wishes to override a few methods to be
* methods to be anonymous. For example: * anonymous. For example:
* *
* <pre> * <pre>
* <code> * <code>
@ -241,8 +244,7 @@ public final class SecurityMockMvcRequestPostProcessors {
* } * }
* // ... lots of tests ran with a default user ... * // ... lots of tests ran with a default user ...
* } * }
* </code> * </code> </pre>
* </pre>
* *
* @return the {@link RequestPostProcessor} to use * @return the {@link RequestPostProcessor} to use
*/ */
@ -254,11 +256,10 @@ public final class SecurityMockMvcRequestPostProcessors {
* Establish the specified {@link SecurityContext} to be used. * Establish the specified {@link SecurityContext} to be used.
* *
* <p> * <p>
* This works by associating the user to the {@link HttpServletRequest}. To * This works by associating the user to the {@link HttpServletRequest}. To associate
* associate the request to the {@link SecurityContextHolder} you need to * the request to the {@link SecurityContextHolder} you need to ensure that the
* ensure that the {@link SecurityContextPersistenceFilter} (i.e. Spring * {@link SecurityContextPersistenceFilter} (i.e. Spring Security's FilterChainProxy
* Security's FilterChainProxy will typically do this) is associated with * will typically do this) is associated with the {@link MockMvc} instance.
* the {@link MockMvc} instance.
* </p> * </p>
*/ */
public static RequestPostProcessor securityContext(SecurityContext securityContext) { public static RequestPostProcessor securityContext(SecurityContext securityContext) {
@ -289,8 +290,10 @@ public final class SecurityMockMvcRequestPostProcessors {
this.certificates = certificates; this.certificates = certificates;
} }
@Override
public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) {
request.setAttribute("javax.servlet.request.X509Certificate", certificates); request.setAttribute("javax.servlet.request.X509Certificate",
this.certificates);
return request; return request;
} }
} }
@ -313,18 +316,20 @@ public final class SecurityMockMvcRequestPostProcessors {
* @see org.springframework.test.web.servlet.request.RequestPostProcessor * @see org.springframework.test.web.servlet.request.RequestPostProcessor
* #postProcessRequest (org.springframework.mock.web.MockHttpServletRequest) * #postProcessRequest (org.springframework.mock.web.MockHttpServletRequest)
*/ */
@Override
public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) {
CsrfTokenRepository repository = WebTestUtils.getCsrfTokenRepository(request); CsrfTokenRepository repository = WebTestUtils.getCsrfTokenRepository(request);
if(!(repository instanceof TestCsrfTokenRepository)) { if (!(repository instanceof TestCsrfTokenRepository)) {
repository = new TestCsrfTokenRepository(repository); repository = new TestCsrfTokenRepository(
new HttpSessionCsrfTokenRepository());
WebTestUtils.setCsrfTokenRepository(request, repository); WebTestUtils.setCsrfTokenRepository(request, repository);
} }
CsrfToken token = repository.generateToken(request); CsrfToken token = repository.generateToken(request);
repository.saveToken(token, request, new MockHttpServletResponse()); repository.saveToken(token, request, new MockHttpServletResponse());
String tokenValue = useInvalidToken ? "invalid" + token.getToken() : token String tokenValue = this.useInvalidToken ? "invalid" + token.getToken()
.getToken(); : token.getToken();
if (asHeader) { if (this.asHeader) {
request.addHeader(token.getHeaderName(), tokenValue); request.addHeader(token.getHeaderName(), tokenValue);
} }
else { else {
@ -357,16 +362,13 @@ public final class SecurityMockMvcRequestPostProcessors {
private CsrfRequestPostProcessor() { private CsrfRequestPostProcessor() {
} }
/** /**
* Used to wrap the CsrfTokenRepository to provide support for testing * Used to wrap the CsrfTokenRepository to provide support for testing when the
* when the request is wrapped (i.e. Spring Session is in use). * request is wrapped (i.e. Spring Session is in use).
*/ */
static class TestCsrfTokenRepository implements static class TestCsrfTokenRepository implements CsrfTokenRepository {
CsrfTokenRepository { final static String ATTR_NAME = TestCsrfTokenRepository.class.getName()
final static String ATTR_NAME = TestCsrfTokenRepository.class .concat(".TOKEN");
.getName().concat(".TOKEN");
private final CsrfTokenRepository delegate; private final CsrfTokenRepository delegate;
@ -374,14 +376,18 @@ public final class SecurityMockMvcRequestPostProcessors {
this.delegate = delegate; this.delegate = delegate;
} }
@Override
public CsrfToken generateToken(HttpServletRequest request) { public CsrfToken generateToken(HttpServletRequest request) {
return delegate.generateToken(request); return this.delegate.generateToken(request);
} }
public void saveToken(CsrfToken token, HttpServletRequest request, HttpServletResponse response) { @Override
public void saveToken(CsrfToken token, HttpServletRequest request,
HttpServletResponse response) {
request.setAttribute(ATTR_NAME, token); request.setAttribute(ATTR_NAME, token);
} }
@Override
public CsrfToken loadToken(HttpServletRequest request) { public CsrfToken loadToken(HttpServletRequest request) {
return (CsrfToken) request.getAttribute(ATTR_NAME); return (CsrfToken) request.getAttribute(ATTR_NAME);
} }
@ -447,14 +453,16 @@ public final class SecurityMockMvcRequestPostProcessors {
private String createAuthorizationHeader(MockHttpServletRequest request) { private String createAuthorizationHeader(MockHttpServletRequest request) {
String uri = request.getRequestURI(); String uri = request.getRequestURI();
String responseDigest = generateDigest(username, realm, password, String responseDigest = generateDigest(this.username, this.realm,
request.getMethod(), uri, qop, nonce, nc, cnonce); this.password, request.getMethod(), uri, this.qop, this.nonce,
return "Digest username=\"" + username + "\", realm=\"" + realm this.nc, this.cnonce);
+ "\", nonce=\"" + nonce + "\", uri=\"" + uri + "\", response=\"" return "Digest username=\"" + this.username + "\", realm=\"" + this.realm
+ responseDigest + "\", qop=" + qop + ", nc=" + nc + ", cnonce=\"" + "\", nonce=\"" + this.nonce + "\", uri=\"" + uri + "\", response=\""
+ cnonce + "\""; + responseDigest + "\", qop=" + this.qop + ", nc=" + this.nc
+ ", cnonce=\"" + this.cnonce + "\"";
} }
@Override
public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) {
request.addHeader("Authorization", createAuthorizationHeader(request)); request.addHeader("Authorization", createAuthorizationHeader(request));
@ -574,8 +582,7 @@ public final class SecurityMockMvcRequestPostProcessors {
* Used to wrap the SecurityContextRepository to provide support for testing in * Used to wrap the SecurityContextRepository to provide support for testing in
* stateless mode * stateless mode
*/ */
static class TestSecurityContextRepository implements static class TestSecurityContextRepository implements SecurityContextRepository {
SecurityContextRepository {
private final static String ATTR_NAME = TestSecurityContextRepository.class private final static String ATTR_NAME = TestSecurityContextRepository.class
.getName().concat(".REPO"); .getName().concat(".REPO");
@ -585,6 +592,7 @@ public final class SecurityMockMvcRequestPostProcessors {
this.delegate = delegate; this.delegate = delegate;
} }
@Override
public SecurityContext loadContext( public SecurityContext loadContext(
HttpRequestResponseHolder requestResponseHolder) { HttpRequestResponseHolder requestResponseHolder) {
SecurityContext result = getContext(requestResponseHolder.getRequest()); SecurityContext result = getContext(requestResponseHolder.getRequest());
@ -592,19 +600,22 @@ public final class SecurityMockMvcRequestPostProcessors {
// holder are updated // holder are updated
// remember the SecurityContextRepository is used in many different // remember the SecurityContextRepository is used in many different
// locations // locations
SecurityContext delegateResult = delegate SecurityContext delegateResult = this.delegate
.loadContext(requestResponseHolder); .loadContext(requestResponseHolder);
return result == null ? delegateResult : result; return result == null ? delegateResult : result;
} }
@Override
public void saveContext(SecurityContext context, HttpServletRequest request, public void saveContext(SecurityContext context, HttpServletRequest request,
HttpServletResponse response) { HttpServletResponse response) {
request.setAttribute(ATTR_NAME, context); request.setAttribute(ATTR_NAME, context);
delegate.saveContext(context, request, response); this.delegate.saveContext(context, request, response);
} }
@Override
public boolean containsContext(HttpServletRequest request) { public boolean containsContext(HttpServletRequest request) {
return getContext(request) != null || delegate.containsContext(request); return getContext(request) != null
|| this.delegate.containsContext(request);
} }
private static SecurityContext getContext(HttpServletRequest request) { private static SecurityContext getContext(HttpServletRequest request) {
@ -625,15 +636,17 @@ public final class SecurityMockMvcRequestPostProcessors {
SecurityContextRequestPostProcessorSupport implements RequestPostProcessor { SecurityContextRequestPostProcessorSupport implements RequestPostProcessor {
private SecurityContext EMPTY = SecurityContextHolder.createEmptyContext(); private SecurityContext EMPTY = SecurityContextHolder.createEmptyContext();
@Override
public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) {
// TestSecurityContextHolder is only a default value // TestSecurityContextHolder is only a default value
SecurityContext existingContext = TestSecurityContextRepository.getContext(request); SecurityContext existingContext = TestSecurityContextRepository
if(existingContext != null) { .getContext(request);
if (existingContext != null) {
return request; return request;
} }
SecurityContext context = TestSecurityContextHolder.getContext(); SecurityContext context = TestSecurityContextHolder.getContext();
if(!EMPTY.equals(context)) { if (!this.EMPTY.equals(context)) {
save(context, request); save(context, request);
} }
@ -657,6 +670,7 @@ public final class SecurityMockMvcRequestPostProcessors {
this.securityContext = securityContext; this.securityContext = securityContext;
} }
@Override
public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) {
save(this.securityContext, request); save(this.securityContext, request);
return request; return request;
@ -679,10 +693,11 @@ public final class SecurityMockMvcRequestPostProcessors {
this.authentication = authentication; this.authentication = authentication;
} }
@Override
public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) {
SecurityContext context = SecurityContextHolder.createEmptyContext(); SecurityContext context = SecurityContextHolder.createEmptyContext();
context.setAuthentication(authentication); context.setAuthentication(this.authentication);
save(authentication, request); save(this.authentication, request);
return request; return request;
} }
} }
@ -695,19 +710,20 @@ public final class SecurityMockMvcRequestPostProcessors {
* @author Rob Winch * @author Rob Winch
* @since 4.0 * @since 4.0
*/ */
private final static class UserDetailsRequestPostProcessor implements private final static class UserDetailsRequestPostProcessor
RequestPostProcessor { implements RequestPostProcessor {
private final RequestPostProcessor delegate; private final RequestPostProcessor delegate;
public UserDetailsRequestPostProcessor(UserDetails user) { public UserDetailsRequestPostProcessor(UserDetails user) {
Authentication token = new UsernamePasswordAuthenticationToken(user, Authentication token = new UsernamePasswordAuthenticationToken(user,
user.getPassword(), user.getAuthorities()); user.getPassword(), user.getAuthorities());
delegate = new AuthenticationRequestPostProcessor(token); this.delegate = new AuthenticationRequestPostProcessor(token);
} }
@Override
public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) {
return delegate.postProcessRequest(request); return this.delegate.postProcessRequest(request);
} }
} }
@ -752,8 +768,8 @@ public final class SecurityMockMvcRequestPostProcessors {
* {@link #authorities(GrantedAuthority...)}, but just not as flexible. * {@link #authorities(GrantedAuthority...)}, but just not as flexible.
* *
* @param roles The roles to populate. Note that if the role does not start with * @param roles The roles to populate. Note that if the role does not start with
* {@link #ROLE_PREFIX} it will automatically be prepended. This means by * {@link #ROLE_PREFIX} it will automatically be prepended. This means by default
* default {@code roles("ROLE_USER")} and {@code roles("USER")} are equivalent. * {@code roles("ROLE_USER")} and {@code roles("USER")} are equivalent.
* @see #authorities(GrantedAuthority...) * @see #authorities(GrantedAuthority...)
* @see #ROLE_PREFIX * @see #ROLE_PREFIX
* @return the UserRequestPostProcessor for further customizations * @return the UserRequestPostProcessor for further customizations
@ -764,8 +780,7 @@ public final class SecurityMockMvcRequestPostProcessors {
for (String role : roles) { for (String role : roles) {
if (role.startsWith(ROLE_PREFIX)) { if (role.startsWith(ROLE_PREFIX)) {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"Role should not start with " "Role should not start with " + ROLE_PREFIX
+ ROLE_PREFIX
+ " since this method automatically prefixes with this value. Got " + " since this method automatically prefixes with this value. Got "
+ role); + role);
} }
@ -812,6 +827,7 @@ public final class SecurityMockMvcRequestPostProcessors {
return this; return this;
} }
@Override
public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) {
UserDetailsRequestPostProcessor delegate = new UserDetailsRequestPostProcessor( UserDetailsRequestPostProcessor delegate = new UserDetailsRequestPostProcessor(
createUser()); createUser());
@ -823,19 +839,27 @@ public final class SecurityMockMvcRequestPostProcessors {
* @return the {@link User} for the principal * @return the {@link User} for the principal
*/ */
private User createUser() { private User createUser() {
return new User(username, password, enabled, accountNonExpired, return new User(this.username, this.password, this.enabled,
credentialsNonExpired, accountNonLocked, authorities); this.accountNonExpired, this.credentialsNonExpired,
this.accountNonLocked, this.authorities);
} }
} }
private static class AnonymousRequestPostProcessor extends SecurityContextRequestPostProcessorSupport implements RequestPostProcessor { private static class AnonymousRequestPostProcessor extends
private AuthenticationRequestPostProcessor delegate = new AuthenticationRequestPostProcessor(new AnonymousAuthenticationToken("key", "anonymous", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"))); SecurityContextRequestPostProcessorSupport implements RequestPostProcessor {
private AuthenticationRequestPostProcessor delegate = new AuthenticationRequestPostProcessor(
new AnonymousAuthenticationToken("key", "anonymous",
AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")));
/* (non-Javadoc) /*
* @see org.springframework.test.web.servlet.request.RequestPostProcessor#postProcessRequest(org.springframework.mock.web.MockHttpServletRequest) * (non-Javadoc)
*
* @see org.springframework.test.web.servlet.request.RequestPostProcessor#
* postProcessRequest(org.springframework.mock.web.MockHttpServletRequest)
*/ */
@Override
public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) {
return delegate.postProcessRequest(request); return this.delegate.postProcessRequest(request);
} }
} }
@ -853,8 +877,9 @@ public final class SecurityMockMvcRequestPostProcessors {
this.headerValue = "Basic " + new String(Base64.encode(toEncode)); this.headerValue = "Basic " + new String(Base64.encode(toEncode));
} }
@Override
public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) {
request.addHeader("Authorization", headerValue); request.addHeader("Authorization", this.headerValue);
return request; return request;
} }
} }

View File

@ -0,0 +1,126 @@
/*
* Copyright 2012-2016 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.web.csrf;
import java.util.UUID;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.util.WebUtils;
/**
* A {@link CsrfTokenRepository} that persist the CSRF token in a cookie named
* "XSRF-TOKEN" and reads from the header "X-XSRF-TOKEN" following the conventions of
* AngularJS.
*
* @author Rob Winch
* @since 4.1
*/
public final class CookieCsrfTokenRepository implements CsrfTokenRepository {
static final String DEFAULT_CSRF_COOKIE_NAME = "XSRF-TOKEN";
static final String DEFAULT_CSRF_PARAMETER_NAME = "_csrf";
static final String DEFAULT_CSRF_HEADER_NAME = "X-XSRF-TOKEN";
private String parameterName = DEFAULT_CSRF_PARAMETER_NAME;
private String headerName = DEFAULT_CSRF_HEADER_NAME;
private String cookieName = DEFAULT_CSRF_COOKIE_NAME;
@Override
public CsrfToken generateToken(HttpServletRequest request) {
return new DefaultCsrfToken(this.headerName, this.parameterName,
createNewToken());
}
@Override
public void saveToken(CsrfToken token, HttpServletRequest request,
HttpServletResponse response) {
String tokenValue = token == null ? "" : token.getToken();
Cookie cookie = new Cookie(this.cookieName, tokenValue);
cookie.setSecure(request.isSecure());
cookie.setPath(getCookiePath(request));
if (token == null) {
cookie.setMaxAge(0);
}
else {
cookie.setMaxAge(-1);
}
response.addCookie(cookie);
}
@Override
public CsrfToken loadToken(HttpServletRequest request) {
Cookie cookie = WebUtils.getCookie(request, this.cookieName);
if (cookie == null) {
return null;
}
String token = cookie.getValue();
if (!StringUtils.hasLength(token)) {
return null;
}
return new DefaultCsrfToken(this.headerName, this.parameterName, token);
}
/**
* Sets the name of the HTTP request parameter that should be used to provide a token.
*
* @param parameterName the name of the HTTP request parameter that should be used to
* provide a token
*/
public void setParameterName(String parameterName) {
Assert.notNull(parameterName, "parameterName is not null");
this.parameterName = parameterName;
}
/**
* Sets the name of the HTTP header that should be used to provide the token
*
* @param headerName the name of the HTTP header that should be used to provide the
* token
*/
public void setHeaderName(String headerName) {
Assert.notNull(headerName, "headerName is not null");
this.headerName = headerName;
}
/**
* Sets the name of the cookie that the expected CSRF token is saved to and read from
*
* @param cookieName the name of the cookie that the expected CSRF token is saved to
* and read from
*/
public void setCookieName(String cookieName) {
Assert.notNull(cookieName, "cookieName is not null");
this.cookieName = cookieName;
}
private String getCookiePath(HttpServletRequest request) {
String contextPath = request.getContextPath();
return contextPath.length() > 0 ? contextPath : "/";
}
private String createNewToken() {
return UUID.randomUUID().toString();
}
}

View File

@ -52,6 +52,7 @@ public final class CsrfAuthenticationStrategy implements SessionAuthenticationSt
* #onAuthentication(org.springframework.security.core.Authentication, * #onAuthentication(org.springframework.security.core.Authentication,
* javax.servlet.http.HttpServletRequest, javax.servlet.http.HttpServletResponse) * javax.servlet.http.HttpServletRequest, javax.servlet.http.HttpServletResponse)
*/ */
@Override
public void onAuthentication(Authentication authentication, public void onAuthentication(Authentication authentication,
HttpServletRequest request, HttpServletResponse response) HttpServletRequest request, HttpServletResponse response)
throws SessionAuthenticationException { throws SessionAuthenticationException {
@ -60,96 +61,10 @@ public final class CsrfAuthenticationStrategy implements SessionAuthenticationSt
this.csrfTokenRepository.saveToken(null, request, response); this.csrfTokenRepository.saveToken(null, request, response);
CsrfToken newToken = this.csrfTokenRepository.generateToken(request); CsrfToken newToken = this.csrfTokenRepository.generateToken(request);
CsrfToken tokenForRequest = new SaveOnAccessCsrfToken( this.csrfTokenRepository.saveToken(newToken, request, response);
this.csrfTokenRepository, request, response, newToken);
request.setAttribute(CsrfToken.class.getName(), tokenForRequest); request.setAttribute(CsrfToken.class.getName(), newToken);
request.setAttribute(newToken.getParameterName(), tokenForRequest); request.setAttribute(newToken.getParameterName(), newToken);
} }
} }
private static final class SaveOnAccessCsrfToken implements CsrfToken {
private transient CsrfTokenRepository tokenRepository;
private transient HttpServletRequest request;
private transient HttpServletResponse response;
private final CsrfToken delegate;
public SaveOnAccessCsrfToken(CsrfTokenRepository tokenRepository,
HttpServletRequest request, HttpServletResponse response,
CsrfToken delegate) {
super();
this.tokenRepository = tokenRepository;
this.request = request;
this.response = response;
this.delegate = delegate;
}
public String getHeaderName() {
return this.delegate.getHeaderName();
}
public String getParameterName() {
return this.delegate.getParameterName();
}
public String getToken() {
saveTokenIfNecessary();
return this.delegate.getToken();
}
@Override
public String toString() {
return "SaveOnAccessCsrfToken [delegate=" + this.delegate + "]";
}
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result
+ ((this.delegate == null) ? 0 : this.delegate.hashCode());
return result;
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null) {
return false;
}
if (getClass() != obj.getClass()) {
return false;
}
SaveOnAccessCsrfToken other = (SaveOnAccessCsrfToken) obj;
if (this.delegate == null) {
if (other.delegate != null) {
return false;
}
}
else if (!this.delegate.equals(other.delegate)) {
return false;
}
return true;
}
private void saveTokenIfNecessary() {
if (this.tokenRepository == null) {
return;
}
synchronized (this) {
if (this.tokenRepository != null) {
this.tokenRepository.saveToken(this.delegate, this.request,
this.response);
this.tokenRepository = null;
this.request = null;
this.response = null;
}
}
}
}
} }

View File

@ -27,27 +27,29 @@ import javax.servlet.http.HttpSession;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.springframework.security.web.access.AccessDeniedHandler; import org.springframework.security.web.access.AccessDeniedHandler;
import org.springframework.security.web.access.AccessDeniedHandlerImpl; import org.springframework.security.web.access.AccessDeniedHandlerImpl;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.security.web.util.UrlUtils; import org.springframework.security.web.util.UrlUtils;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.web.filter.OncePerRequestFilter; import org.springframework.web.filter.OncePerRequestFilter;
/** /**
* <p> * <p>
* Applies <a href="https://www.owasp.org/index.php/Cross-Site_Request_Forgery_(CSRF)" * Applies
* >CSRF</a> protection using a synchronizer token pattern. Developers are required to * <a href="https://www.owasp.org/index.php/Cross-Site_Request_Forgery_(CSRF)" >CSRF</a>
* ensure that {@link CsrfFilter} is invoked for any request that allows state to change. * protection using a synchronizer token pattern. Developers are required to ensure that
* Typically this just means that they should ensure their web application follows proper * {@link CsrfFilter} is invoked for any request that allows state to change. Typically
* REST semantics (i.e. do not change state with the HTTP methods GET, HEAD, TRACE, * this just means that they should ensure their web application follows proper REST
* OPTIONS). * semantics (i.e. do not change state with the HTTP methods GET, HEAD, TRACE, OPTIONS).
* </p> * </p>
* *
* <p> * <p>
* Typically the {@link CsrfTokenRepository} implementation chooses to store the * Typically the {@link CsrfTokenRepository} implementation chooses to store the
* {@link CsrfToken} in {@link HttpSession} with {@link HttpSessionCsrfTokenRepository}. * {@link CsrfToken} in {@link HttpSession} with {@link HttpSessionCsrfTokenRepository}
* This is preferred to storing the token in a cookie which can be modified by a client application. * wrapped by a {@link LazyCsrfTokenRepository}. This is preferred to storing the token in
* a cookie which can be modified by a client application.
* </p> * </p>
* *
* @author Rob Winch * @author Rob Winch
@ -82,18 +84,19 @@ public final class CsrfFilter extends OncePerRequestFilter {
@Override @Override
protected void doFilterInternal(HttpServletRequest request, protected void doFilterInternal(HttpServletRequest request,
HttpServletResponse response, FilterChain filterChain) HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException { throws ServletException, IOException {
CsrfToken csrfToken = tokenRepository.loadToken(request); request.setAttribute(HttpServletResponse.class.getName(), response);
CsrfToken csrfToken = this.tokenRepository.loadToken(request);
final boolean missingToken = csrfToken == null; final boolean missingToken = csrfToken == null;
if (missingToken) { if (missingToken) {
CsrfToken generatedToken = tokenRepository.generateToken(request); csrfToken = this.tokenRepository.generateToken(request);
csrfToken = new SaveOnAccessCsrfToken(tokenRepository, request, response, this.tokenRepository.saveToken(csrfToken, request, response);
generatedToken);
} }
request.setAttribute(CsrfToken.class.getName(), csrfToken); request.setAttribute(CsrfToken.class.getName(), csrfToken);
request.setAttribute(csrfToken.getParameterName(), csrfToken); request.setAttribute(csrfToken.getParameterName(), csrfToken);
if (!requireCsrfProtectionMatcher.matches(request)) { if (!this.requireCsrfProtectionMatcher.matches(request)) {
filterChain.doFilter(request, response); filterChain.doFilter(request, response);
return; return;
} }
@ -103,16 +106,16 @@ public final class CsrfFilter extends OncePerRequestFilter {
actualToken = request.getParameter(csrfToken.getParameterName()); actualToken = request.getParameter(csrfToken.getParameterName());
} }
if (!csrfToken.getToken().equals(actualToken)) { if (!csrfToken.getToken().equals(actualToken)) {
if (logger.isDebugEnabled()) { if (this.logger.isDebugEnabled()) {
logger.debug("Invalid CSRF token found for " this.logger.debug("Invalid CSRF token found for "
+ UrlUtils.buildFullRequestUrl(request)); + UrlUtils.buildFullRequestUrl(request));
} }
if (missingToken) { if (missingToken) {
accessDeniedHandler.handle(request, response, this.accessDeniedHandler.handle(request, response,
new MissingCsrfTokenException(actualToken)); new MissingCsrfTokenException(actualToken));
} }
else { else {
accessDeniedHandler.handle(request, response, this.accessDeniedHandler.handle(request, response,
new InvalidCsrfTokenException(csrfToken, actualToken)); new InvalidCsrfTokenException(csrfToken, actualToken));
} }
return; return;
@ -156,87 +159,9 @@ public final class CsrfFilter extends OncePerRequestFilter {
this.accessDeniedHandler = accessDeniedHandler; this.accessDeniedHandler = accessDeniedHandler;
} }
@SuppressWarnings("serial")
private static final class SaveOnAccessCsrfToken implements CsrfToken {
private transient CsrfTokenRepository tokenRepository;
private transient HttpServletRequest request;
private transient HttpServletResponse response;
private final CsrfToken delegate;
public SaveOnAccessCsrfToken(CsrfTokenRepository tokenRepository,
HttpServletRequest request, HttpServletResponse response,
CsrfToken delegate) {
super();
this.tokenRepository = tokenRepository;
this.request = request;
this.response = response;
this.delegate = delegate;
}
public String getHeaderName() {
return delegate.getHeaderName();
}
public String getParameterName() {
return delegate.getParameterName();
}
public String getToken() {
saveTokenIfNecessary();
return delegate.getToken();
}
@Override
public String toString() {
return "SaveOnAccessCsrfToken [delegate=" + delegate + "]";
}
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + ((delegate == null) ? 0 : delegate.hashCode());
return result;
}
@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
if (obj == null)
return false;
if (getClass() != obj.getClass())
return false;
SaveOnAccessCsrfToken other = (SaveOnAccessCsrfToken) obj;
if (delegate == null) {
if (other.delegate != null)
return false;
}
else if (!delegate.equals(other.delegate))
return false;
return true;
}
private void saveTokenIfNecessary() {
if (this.tokenRepository == null) {
return;
}
synchronized (this) {
if (tokenRepository != null) {
this.tokenRepository.saveToken(delegate, request, response);
this.tokenRepository = null;
this.request = null;
this.response = null;
}
}
}
}
private static final class DefaultRequiresCsrfMatcher implements RequestMatcher { private static final class DefaultRequiresCsrfMatcher implements RequestMatcher {
private final HashSet<String> allowedMethods = new HashSet<String>(Arrays.asList("GET", "HEAD", "TRACE", "OPTIONS")); private final HashSet<String> allowedMethods = new HashSet<String>(
Arrays.asList("GET", "HEAD", "TRACE", "OPTIONS"));
/* /*
* (non-Javadoc) * (non-Javadoc)
@ -245,8 +170,9 @@ public final class CsrfFilter extends OncePerRequestFilter {
* org.springframework.security.web.util.matcher.RequestMatcher#matches(javax. * org.springframework.security.web.util.matcher.RequestMatcher#matches(javax.
* servlet.http.HttpServletRequest) * servlet.http.HttpServletRequest)
*/ */
@Override
public boolean matches(HttpServletRequest request) { public boolean matches(HttpServletRequest request) {
return !allowedMethods.contains(request.getMethod()); return !this.allowedMethods.contains(request.getMethod());
} }
} }
} }

View File

@ -0,0 +1,186 @@
/*
* Copyright 2012-2016 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.web.csrf;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.util.Assert;
/**
* A {@link CsrfTokenRepository} that delays saving new {@link CsrfToken} until the
* attributes of the {@link CsrfToken} that were generated are accessed.
*
* @author Rob Winch
* @since 4.1
*/
public final class LazyCsrfTokenRepository implements CsrfTokenRepository {
/**
* The {@link HttpServletRequest} attribute name that the {@link HttpServletResponse}
* must be on.
*/
private static final String HTTP_RESPONSE_ATTR = HttpServletResponse.class.getName();
private final CsrfTokenRepository delegate;
/**
* Creates a new instance
* @param delegate the {@link CsrfTokenRepository} to use. Cannot be null
* @throws IllegalArgumentException if delegate is null.
*/
public LazyCsrfTokenRepository(CsrfTokenRepository delegate) {
Assert.notNull(delegate, "delegate cannot be null");
this.delegate = delegate;
}
/**
* Generates a new token
* @param request the {@link HttpServletRequest} to use. The
* {@link HttpServletRequest} must have the {@link HttpServletResponse} as an
* attribute with the name of <code>HttpServletResponse.class.getName()</code>
*/
@Override
public CsrfToken generateToken(HttpServletRequest request) {
return wrap(request, this.delegate.generateToken(request));
}
/**
* Does nothing if the {@link CsrfToken} is not null. Saving is done only when the
* {@link CsrfToken#getToken()} iis accessed from
* {@link #generateToken(HttpServletRequest)}. If it is null, then the save is
* performed immediately.
*/
@Override
public void saveToken(CsrfToken token, HttpServletRequest request,
HttpServletResponse response) {
if (token == null) {
this.delegate.saveToken(token, request, response);
}
}
/**
* Delegates to the injected {@link CsrfTokenRepository}
*/
@Override
public CsrfToken loadToken(HttpServletRequest request) {
return this.delegate.loadToken(request);
}
private CsrfToken wrap(HttpServletRequest request, CsrfToken token) {
HttpServletResponse response = getResponse(request);
return new SaveOnAccessCsrfToken(this.delegate, request, response, token);
}
private HttpServletResponse getResponse(HttpServletRequest request) {
HttpServletResponse response = (HttpServletResponse) request
.getAttribute(HTTP_RESPONSE_ATTR);
if (response == null) {
throw new IllegalArgumentException(
"The HttpServletRequest attribute must contain an HttpServletResponse for the attribute "
+ HTTP_RESPONSE_ATTR);
}
return response;
}
private static final class SaveOnAccessCsrfToken implements CsrfToken {
private transient CsrfTokenRepository tokenRepository;
private transient HttpServletRequest request;
private transient HttpServletResponse response;
private final CsrfToken delegate;
SaveOnAccessCsrfToken(CsrfTokenRepository tokenRepository,
HttpServletRequest request, HttpServletResponse response,
CsrfToken delegate) {
super();
this.tokenRepository = tokenRepository;
this.request = request;
this.response = response;
this.delegate = delegate;
}
@Override
public String getHeaderName() {
return this.delegate.getHeaderName();
}
@Override
public String getParameterName() {
return this.delegate.getParameterName();
}
@Override
public String getToken() {
saveTokenIfNecessary();
return this.delegate.getToken();
}
@Override
public String toString() {
return "SaveOnAccessCsrfToken [delegate=" + this.delegate + "]";
}
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result
+ ((this.delegate == null) ? 0 : this.delegate.hashCode());
return result;
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null) {
return false;
}
if (getClass() != obj.getClass()) {
return false;
}
SaveOnAccessCsrfToken other = (SaveOnAccessCsrfToken) obj;
if (this.delegate == null) {
if (other.delegate != null) {
return false;
}
}
else if (!this.delegate.equals(other.delegate)) {
return false;
}
return true;
}
private void saveTokenIfNecessary() {
if (this.tokenRepository == null) {
return;
}
synchronized (this) {
if (this.tokenRepository != null) {
this.tokenRepository.saveToken(this.delegate, this.request,
this.response);
this.tokenRepository = null;
this.request = null;
this.response = null;
}
}
}
}
}

View File

@ -0,0 +1,189 @@
/*
* Copyright 2012-2016 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.web.csrf;
import javax.servlet.http.Cookie;
import org.junit.Before;
import org.junit.Test;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import static org.assertj.core.api.Assertions.assertThat;
/**
* @author Rob Winch
* @since 4.1
*/
public class CookieCsrfTokenRepositoryTests {
CookieCsrfTokenRepository repository;
MockHttpServletResponse response;
MockHttpServletRequest request;
@Before
public void setup() {
this.repository = new CookieCsrfTokenRepository();
this.request = new MockHttpServletRequest();
this.response = new MockHttpServletResponse();
this.request.setContextPath("/context");
}
@Test
public void generateToken() {
CsrfToken generateToken = this.repository.generateToken(this.request);
assertThat(generateToken).isNotNull();
assertThat(generateToken.getHeaderName())
.isEqualTo(CookieCsrfTokenRepository.DEFAULT_CSRF_HEADER_NAME);
assertThat(generateToken.getParameterName())
.isEqualTo(CookieCsrfTokenRepository.DEFAULT_CSRF_PARAMETER_NAME);
assertThat(generateToken.getToken()).isNotEmpty();
}
@Test
public void generateTokenCustom() {
String headerName = "headerName";
String parameterName = "paramName";
this.repository.setHeaderName(headerName);
this.repository.setParameterName(parameterName);
CsrfToken generateToken = this.repository.generateToken(this.request);
assertThat(generateToken).isNotNull();
assertThat(generateToken.getHeaderName()).isEqualTo(headerName);
assertThat(generateToken.getParameterName()).isEqualTo(parameterName);
assertThat(generateToken.getToken()).isNotEmpty();
}
@Test
public void saveToken() {
CsrfToken token = this.repository.generateToken(this.request);
this.repository.saveToken(token, this.request, this.response);
Cookie tokenCookie = this.response
.getCookie(CookieCsrfTokenRepository.DEFAULT_CSRF_COOKIE_NAME);
assertThat(tokenCookie.getMaxAge()).isEqualTo(-1);
assertThat(tokenCookie.getName())
.isEqualTo(CookieCsrfTokenRepository.DEFAULT_CSRF_COOKIE_NAME);
assertThat(tokenCookie.getPath()).isEqualTo(this.request.getContextPath());
assertThat(tokenCookie.getSecure()).isEqualTo(this.request.isSecure());
assertThat(tokenCookie.getValue()).isEqualTo(token.getToken());
}
@Test
public void saveTokenSecure() {
this.request.setSecure(true);
CsrfToken token = this.repository.generateToken(this.request);
this.repository.saveToken(token, this.request, this.response);
Cookie tokenCookie = this.response
.getCookie(CookieCsrfTokenRepository.DEFAULT_CSRF_COOKIE_NAME);
assertThat(tokenCookie.getSecure()).isTrue();
}
@Test
public void saveTokenNull() {
this.request.setSecure(true);
this.repository.saveToken(null, this.request, this.response);
Cookie tokenCookie = this.response
.getCookie(CookieCsrfTokenRepository.DEFAULT_CSRF_COOKIE_NAME);
assertThat(tokenCookie.getMaxAge()).isEqualTo(0);
assertThat(tokenCookie.getName())
.isEqualTo(CookieCsrfTokenRepository.DEFAULT_CSRF_COOKIE_NAME);
assertThat(tokenCookie.getPath()).isEqualTo(this.request.getContextPath());
assertThat(tokenCookie.getSecure()).isEqualTo(this.request.isSecure());
assertThat(tokenCookie.getValue()).isEmpty();
}
@Test
public void loadTokenNoCookiesNull() {
assertThat(this.repository.loadToken(this.request)).isNull();
}
@Test
public void loadTokenCookieIncorrectNameNull() {
this.request.setCookies(new Cookie("other", "name"));
assertThat(this.repository.loadToken(this.request)).isNull();
}
@Test
public void loadTokenCookieValueEmptyString() {
this.request.setCookies(
new Cookie(CookieCsrfTokenRepository.DEFAULT_CSRF_COOKIE_NAME, ""));
assertThat(this.repository.loadToken(this.request)).isNull();
}
@Test
public void loadToken() {
CsrfToken generateToken = this.repository.generateToken(this.request);
this.request
.setCookies(new Cookie(CookieCsrfTokenRepository.DEFAULT_CSRF_COOKIE_NAME,
generateToken.getToken()));
CsrfToken loadToken = this.repository.loadToken(this.request);
assertThat(loadToken).isNotNull();
assertThat(loadToken.getHeaderName()).isEqualTo(generateToken.getHeaderName());
assertThat(loadToken.getParameterName())
.isEqualTo(generateToken.getParameterName());
assertThat(loadToken.getToken()).isNotEmpty();
}
@Test
public void loadTokenCustom() {
String cookieName = "cookieName";
String value = "value";
String headerName = "headerName";
String parameterName = "paramName";
this.repository.setHeaderName(headerName);
this.repository.setParameterName(parameterName);
this.repository.setCookieName(cookieName);
this.request.setCookies(new Cookie(cookieName, value));
CsrfToken loadToken = this.repository.loadToken(this.request);
assertThat(loadToken).isNotNull();
assertThat(loadToken.getHeaderName()).isEqualTo(headerName);
assertThat(loadToken.getParameterName()).isEqualTo(parameterName);
assertThat(loadToken.getToken()).isEqualTo(value);
}
@Test(expected = IllegalArgumentException.class)
public void setCookieNameNullIllegalArgumentException() {
this.repository.setCookieName(null);
}
@Test(expected = IllegalArgumentException.class)
public void setParameterNameNullIllegalArgumentException() {
this.repository.setParameterName(null);
}
@Test(expected = IllegalArgumentException.class)
public void setHeaderNameNullIllegalArgumentException() {
this.repository.setHeaderName(null);
}
}

View File

@ -15,13 +15,6 @@
*/ */
package org.springframework.security.web.csrf; package org.springframework.security.web.csrf;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
@ -30,10 +23,18 @@ import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner; import org.mockito.runners.MockitoJUnitRunner;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.authentication.TestingAuthenticationToken;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
/** /**
* @author Rob Winch * @author Rob Winch
* *
@ -55,11 +56,12 @@ public class CsrfAuthenticationStrategyTests {
@Before @Before
public void setup() { public void setup() {
request = new MockHttpServletRequest(); this.response = new MockHttpServletResponse();
response = new MockHttpServletResponse(); this.request = new MockHttpServletRequest();
strategy = new CsrfAuthenticationStrategy(csrfTokenRepository); this.request.setAttribute(HttpServletResponse.class.getName(), this.response);
existingToken = new DefaultCsrfToken("_csrf", "_csrf", "1"); this.strategy = new CsrfAuthenticationStrategy(this.csrfTokenRepository);
generatedToken = new DefaultCsrfToken("_csrf", "_csrf", "2"); this.existingToken = new DefaultCsrfToken("_csrf", "_csrf", "1");
this.generatedToken = new DefaultCsrfToken("_csrf", "_csrf", "2");
} }
@Test(expected = IllegalArgumentException.class) @Test(expected = IllegalArgumentException.class)
@ -69,51 +71,61 @@ public class CsrfAuthenticationStrategyTests {
@Test @Test
public void logoutRemovesCsrfTokenAndSavesNew() { public void logoutRemovesCsrfTokenAndSavesNew() {
when(csrfTokenRepository.loadToken(request)).thenReturn(existingToken); when(this.csrfTokenRepository.loadToken(this.request))
when(csrfTokenRepository.generateToken(request)).thenReturn(generatedToken); .thenReturn(this.existingToken);
strategy.onAuthentication(new TestingAuthenticationToken("user", "password", when(this.csrfTokenRepository.generateToken(this.request))
"ROLE_USER"), request, response); .thenReturn(this.generatedToken);
this.strategy.onAuthentication(
new TestingAuthenticationToken("user", "password", "ROLE_USER"),
this.request, this.response);
verify(csrfTokenRepository).saveToken(null, request, response); verify(this.csrfTokenRepository).saveToken(null, this.request, this.response);
verify(csrfTokenRepository, never()).saveToken(eq(generatedToken), verify(this.csrfTokenRepository).saveToken(eq(this.generatedToken),
any(HttpServletRequest.class), any(HttpServletResponse.class)); any(HttpServletRequest.class), any(HttpServletResponse.class));
// SEC-2404, SEC-2832 // SEC-2404, SEC-2832
CsrfToken tokenInRequest = (CsrfToken) request.getAttribute(CsrfToken.class CsrfToken tokenInRequest = (CsrfToken) this.request
.getName()); .getAttribute(CsrfToken.class.getName());
assertThat(tokenInRequest.getToken()).isSameAs(generatedToken.getToken()); assertThat(tokenInRequest.getToken()).isSameAs(this.generatedToken.getToken());
assertThat(tokenInRequest.getHeaderName()).isSameAs( assertThat(tokenInRequest.getHeaderName())
generatedToken.getHeaderName()); .isSameAs(this.generatedToken.getHeaderName());
assertThat(tokenInRequest.getParameterName()).isSameAs( assertThat(tokenInRequest.getParameterName())
generatedToken.getParameterName()); .isSameAs(this.generatedToken.getParameterName());
assertThat(request.getAttribute(generatedToken.getParameterName())).isSameAs( assertThat(this.request.getAttribute(this.generatedToken.getParameterName()))
tokenInRequest); .isSameAs(tokenInRequest);
} }
// SEC-2872 // SEC-2872
@Test @Test
public void delaySavingCsrf() { public void delaySavingCsrf() {
when(csrfTokenRepository.loadToken(request)).thenReturn(existingToken); this.strategy = new CsrfAuthenticationStrategy(
when(csrfTokenRepository.generateToken(request)).thenReturn(generatedToken); new LazyCsrfTokenRepository(this.csrfTokenRepository));
strategy.onAuthentication(new TestingAuthenticationToken("user", "password",
"ROLE_USER"), request, response);
verify(csrfTokenRepository).saveToken(null, request, response); when(this.csrfTokenRepository.loadToken(this.request))
verify(csrfTokenRepository, never()).saveToken(eq(generatedToken), .thenReturn(this.existingToken);
when(this.csrfTokenRepository.generateToken(this.request))
.thenReturn(this.generatedToken);
this.strategy.onAuthentication(
new TestingAuthenticationToken("user", "password", "ROLE_USER"),
this.request, this.response);
verify(this.csrfTokenRepository).saveToken(null, this.request, this.response);
verify(this.csrfTokenRepository, never()).saveToken(eq(this.generatedToken),
any(HttpServletRequest.class), any(HttpServletResponse.class)); any(HttpServletRequest.class), any(HttpServletResponse.class));
CsrfToken tokenInRequest = (CsrfToken) request.getAttribute(CsrfToken.class CsrfToken tokenInRequest = (CsrfToken) this.request
.getName()); .getAttribute(CsrfToken.class.getName());
tokenInRequest.getToken(); tokenInRequest.getToken();
verify(csrfTokenRepository).saveToken(eq(generatedToken), verify(this.csrfTokenRepository).saveToken(eq(this.generatedToken),
any(HttpServletRequest.class), any(HttpServletResponse.class)); any(HttpServletRequest.class), any(HttpServletResponse.class));
} }
@Test @Test
public void logoutRemovesNoActionIfNullToken() { public void logoutRemovesNoActionIfNullToken() {
strategy.onAuthentication(new TestingAuthenticationToken("user", "password", this.strategy.onAuthentication(
"ROLE_USER"), request, response); new TestingAuthenticationToken("user", "password", "ROLE_USER"),
this.request, this.response);
verify(csrfTokenRepository, never()).saveToken(any(CsrfToken.class), verify(this.csrfTokenRepository, never()).saveToken(any(CsrfToken.class),
any(HttpServletRequest.class), any(HttpServletResponse.class)); any(HttpServletRequest.class), any(HttpServletResponse.class));
} }
} }

View File

@ -15,14 +15,6 @@
*/ */
package org.springframework.security.web.csrf; package org.springframework.security.web.csrf;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
@ -38,11 +30,21 @@ import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner; import org.mockito.runners.MockitoJUnitRunner;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.web.access.AccessDeniedHandler; import org.springframework.security.web.access.AccessDeniedHandler;
import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
/** /**
* @author Rob Winch * @author Rob Winch
* *
@ -67,16 +69,21 @@ public class CsrfFilterTests {
@Before @Before
public void setup() { public void setup() {
token = new DefaultCsrfToken("headerName", "paramName", "csrfTokenValue"); this.token = new DefaultCsrfToken("headerName", "paramName", "csrfTokenValue");
resetRequestResponse(); resetRequestResponse();
filter = new CsrfFilter(tokenRepository); this.filter = createCsrfFilter(this.tokenRepository);
filter.setRequireCsrfProtectionMatcher(requestMatcher); }
filter.setAccessDeniedHandler(deniedHandler);
private CsrfFilter createCsrfFilter(CsrfTokenRepository repository) {
CsrfFilter filter = new CsrfFilter(repository);
filter.setRequireCsrfProtectionMatcher(this.requestMatcher);
filter.setAccessDeniedHandler(this.deniedHandler);
return filter;
} }
private void resetRequestResponse() { private void resetRequestResponse() {
request = new MockHttpServletRequest(); this.request = new MockHttpServletRequest();
response = new MockHttpServletResponse(); this.response = new MockHttpServletResponse();
} }
@Test(expected = IllegalArgumentException.class) @Test(expected = IllegalArgumentException.class)
@ -86,282 +93,319 @@ public class CsrfFilterTests {
// SEC-2276 // SEC-2276
@Test @Test
public void doFilterDoesNotSaveCsrfTokenUntilAccessed() throws ServletException, public void doFilterDoesNotSaveCsrfTokenUntilAccessed()
IOException { throws ServletException, IOException {
when(requestMatcher.matches(request)).thenReturn(false); this.filter = createCsrfFilter(new LazyCsrfTokenRepository(this.tokenRepository));
when(tokenRepository.generateToken(request)).thenReturn(token); when(this.requestMatcher.matches(this.request)).thenReturn(false);
when(this.tokenRepository.generateToken(this.request)).thenReturn(this.token);
filter.doFilter(request, response, filterChain); this.filter.doFilter(this.request, this.response, this.filterChain);
CsrfToken attrToken = (CsrfToken) request.getAttribute(token.getParameterName()); CsrfToken attrToken = (CsrfToken) this.request
.getAttribute(this.token.getParameterName());
// no CsrfToken should have been saved yet // no CsrfToken should have been saved yet
verify(tokenRepository, times(0)).saveToken(any(CsrfToken.class), verify(this.tokenRepository, times(0)).saveToken(any(CsrfToken.class),
any(HttpServletRequest.class), any(HttpServletResponse.class)); any(HttpServletRequest.class), any(HttpServletResponse.class));
verify(filterChain).doFilter(request, response); verify(this.filterChain).doFilter(this.request, this.response);
// access the token // access the token
attrToken.getToken(); attrToken.getToken();
// now the CsrfToken should have been saved // now the CsrfToken should have been saved
verify(tokenRepository).saveToken(eq(token), any(HttpServletRequest.class), verify(this.tokenRepository).saveToken(eq(this.token),
any(HttpServletResponse.class)); any(HttpServletRequest.class), any(HttpServletResponse.class));
} }
@Test @Test
public void doFilterAccessDeniedNoTokenPresent() throws ServletException, IOException { public void doFilterAccessDeniedNoTokenPresent()
when(requestMatcher.matches(request)).thenReturn(true); throws ServletException, IOException {
when(tokenRepository.loadToken(request)).thenReturn(token); when(this.requestMatcher.matches(this.request)).thenReturn(true);
when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token);
filter.doFilter(request, response, filterChain); this.filter.doFilter(this.request, this.response, this.filterChain);
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); assertThat(this.request.getAttribute(this.token.getParameterName()))
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); .isEqualTo(this.token);
assertThat(this.request.getAttribute(CsrfToken.class.getName()))
.isEqualTo(this.token);
verify(deniedHandler).handle(eq(request), eq(response), verify(this.deniedHandler).handle(eq(this.request), eq(this.response),
any(InvalidCsrfTokenException.class)); any(InvalidCsrfTokenException.class));
verifyZeroInteractions(filterChain); verifyZeroInteractions(this.filterChain);
} }
@Test @Test
public void doFilterAccessDeniedIncorrectTokenPresent() throws ServletException, public void doFilterAccessDeniedIncorrectTokenPresent()
IOException { throws ServletException, IOException {
when(requestMatcher.matches(request)).thenReturn(true); when(this.requestMatcher.matches(this.request)).thenReturn(true);
when(tokenRepository.loadToken(request)).thenReturn(token); when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token);
request.setParameter(token.getParameterName(), token.getToken() + " INVALID"); this.request.setParameter(this.token.getParameterName(),
this.token.getToken() + " INVALID");
filter.doFilter(request, response, filterChain); this.filter.doFilter(this.request, this.response, this.filterChain);
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); assertThat(this.request.getAttribute(this.token.getParameterName()))
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); .isEqualTo(this.token);
assertThat(this.request.getAttribute(CsrfToken.class.getName()))
.isEqualTo(this.token);
verify(deniedHandler).handle(eq(request), eq(response), verify(this.deniedHandler).handle(eq(this.request), eq(this.response),
any(InvalidCsrfTokenException.class)); any(InvalidCsrfTokenException.class));
verifyZeroInteractions(filterChain); verifyZeroInteractions(this.filterChain);
} }
@Test @Test
public void doFilterAccessDeniedIncorrectTokenPresentHeader() public void doFilterAccessDeniedIncorrectTokenPresentHeader()
throws ServletException, IOException { throws ServletException, IOException {
when(requestMatcher.matches(request)).thenReturn(true); when(this.requestMatcher.matches(this.request)).thenReturn(true);
when(tokenRepository.loadToken(request)).thenReturn(token); when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token);
request.addHeader(token.getHeaderName(), token.getToken() + " INVALID"); this.request.addHeader(this.token.getHeaderName(),
this.token.getToken() + " INVALID");
filter.doFilter(request, response, filterChain); this.filter.doFilter(this.request, this.response, this.filterChain);
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); assertThat(this.request.getAttribute(this.token.getParameterName()))
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); .isEqualTo(this.token);
assertThat(this.request.getAttribute(CsrfToken.class.getName()))
.isEqualTo(this.token);
verify(deniedHandler).handle(eq(request), eq(response), verify(this.deniedHandler).handle(eq(this.request), eq(this.response),
any(InvalidCsrfTokenException.class)); any(InvalidCsrfTokenException.class));
verifyZeroInteractions(filterChain); verifyZeroInteractions(this.filterChain);
} }
@Test @Test
public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter() public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter()
throws ServletException, IOException { throws ServletException, IOException {
when(requestMatcher.matches(request)).thenReturn(true); when(this.requestMatcher.matches(this.request)).thenReturn(true);
when(tokenRepository.loadToken(request)).thenReturn(token); when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token);
request.setParameter(token.getParameterName(), token.getToken()); this.request.setParameter(this.token.getParameterName(), this.token.getToken());
request.addHeader(token.getHeaderName(), token.getToken() + " INVALID"); this.request.addHeader(this.token.getHeaderName(),
this.token.getToken() + " INVALID");
filter.doFilter(request, response, filterChain); this.filter.doFilter(this.request, this.response, this.filterChain);
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); assertThat(this.request.getAttribute(this.token.getParameterName()))
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); .isEqualTo(this.token);
assertThat(this.request.getAttribute(CsrfToken.class.getName()))
.isEqualTo(this.token);
verify(deniedHandler).handle(eq(request), eq(response), verify(this.deniedHandler).handle(eq(this.request), eq(this.response),
any(InvalidCsrfTokenException.class)); any(InvalidCsrfTokenException.class));
verifyZeroInteractions(filterChain); verifyZeroInteractions(this.filterChain);
} }
@Test @Test
public void doFilterNotCsrfRequestExistingToken() throws ServletException, public void doFilterNotCsrfRequestExistingToken()
IOException { throws ServletException, IOException {
when(requestMatcher.matches(request)).thenReturn(false); when(this.requestMatcher.matches(this.request)).thenReturn(false);
when(tokenRepository.loadToken(request)).thenReturn(token); when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token);
filter.doFilter(request, response, filterChain); this.filter.doFilter(this.request, this.response, this.filterChain);
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); assertThat(this.request.getAttribute(this.token.getParameterName()))
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); .isEqualTo(this.token);
assertThat(this.request.getAttribute(CsrfToken.class.getName()))
.isEqualTo(this.token);
verify(filterChain).doFilter(request, response); verify(this.filterChain).doFilter(this.request, this.response);
verifyZeroInteractions(deniedHandler); verifyZeroInteractions(this.deniedHandler);
} }
@Test @Test
public void doFilterNotCsrfRequestGenerateToken() throws ServletException, public void doFilterNotCsrfRequestGenerateToken()
IOException { throws ServletException, IOException {
when(requestMatcher.matches(request)).thenReturn(false); when(this.requestMatcher.matches(this.request)).thenReturn(false);
when(tokenRepository.generateToken(request)).thenReturn(token); when(this.tokenRepository.generateToken(this.request)).thenReturn(this.token);
filter.doFilter(request, response, filterChain); this.filter.doFilter(this.request, this.response, this.filterChain);
assertToken(request.getAttribute(token.getParameterName())).isEqualTo(token); assertToken(this.request.getAttribute(this.token.getParameterName()))
assertToken(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); .isEqualTo(this.token);
assertToken(this.request.getAttribute(CsrfToken.class.getName()))
.isEqualTo(this.token);
verify(filterChain).doFilter(request, response); verify(this.filterChain).doFilter(this.request, this.response);
verifyZeroInteractions(deniedHandler); verifyZeroInteractions(this.deniedHandler);
} }
@Test @Test
public void doFilterIsCsrfRequestExistingTokenHeader() throws ServletException, public void doFilterIsCsrfRequestExistingTokenHeader()
IOException { throws ServletException, IOException {
when(requestMatcher.matches(request)).thenReturn(true); when(this.requestMatcher.matches(this.request)).thenReturn(true);
when(tokenRepository.loadToken(request)).thenReturn(token); when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token);
request.addHeader(token.getHeaderName(), token.getToken()); this.request.addHeader(this.token.getHeaderName(), this.token.getToken());
filter.doFilter(request, response, filterChain); this.filter.doFilter(this.request, this.response, this.filterChain);
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); assertThat(this.request.getAttribute(this.token.getParameterName()))
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); .isEqualTo(this.token);
assertThat(this.request.getAttribute(CsrfToken.class.getName()))
.isEqualTo(this.token);
verify(filterChain).doFilter(request, response); verify(this.filterChain).doFilter(this.request, this.response);
verifyZeroInteractions(deniedHandler); verifyZeroInteractions(this.deniedHandler);
} }
@Test @Test
public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam() public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam()
throws ServletException, IOException { throws ServletException, IOException {
when(requestMatcher.matches(request)).thenReturn(true); when(this.requestMatcher.matches(this.request)).thenReturn(true);
when(tokenRepository.loadToken(request)).thenReturn(token); when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token);
request.setParameter(token.getParameterName(), token.getToken() + " INVALID"); this.request.setParameter(this.token.getParameterName(),
request.addHeader(token.getHeaderName(), token.getToken()); this.token.getToken() + " INVALID");
this.request.addHeader(this.token.getHeaderName(), this.token.getToken());
filter.doFilter(request, response, filterChain); this.filter.doFilter(this.request, this.response, this.filterChain);
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); assertThat(this.request.getAttribute(this.token.getParameterName()))
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); .isEqualTo(this.token);
assertThat(this.request.getAttribute(CsrfToken.class.getName()))
.isEqualTo(this.token);
verify(filterChain).doFilter(request, response); verify(this.filterChain).doFilter(this.request, this.response);
verifyZeroInteractions(deniedHandler); verifyZeroInteractions(this.deniedHandler);
} }
@Test @Test
public void doFilterIsCsrfRequestExistingToken() throws ServletException, IOException { public void doFilterIsCsrfRequestExistingToken()
when(requestMatcher.matches(request)).thenReturn(true); throws ServletException, IOException {
when(tokenRepository.loadToken(request)).thenReturn(token); when(this.requestMatcher.matches(this.request)).thenReturn(true);
request.setParameter(token.getParameterName(), token.getToken()); when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token);
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
filter.doFilter(request, response, filterChain); this.filter.doFilter(this.request, this.response, this.filterChain);
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); assertThat(this.request.getAttribute(this.token.getParameterName()))
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); .isEqualTo(this.token);
assertThat(this.request.getAttribute(CsrfToken.class.getName()))
.isEqualTo(this.token);
verify(filterChain).doFilter(request, response); verify(this.filterChain).doFilter(this.request, this.response);
verifyZeroInteractions(deniedHandler); verifyZeroInteractions(this.deniedHandler);
verify(this.tokenRepository, never()).saveToken(any(CsrfToken.class),
any(HttpServletRequest.class), any(HttpServletResponse.class));
} }
@Test @Test
public void doFilterIsCsrfRequestGenerateToken() throws ServletException, IOException { public void doFilterIsCsrfRequestGenerateToken()
when(requestMatcher.matches(request)).thenReturn(true); throws ServletException, IOException {
when(tokenRepository.generateToken(request)).thenReturn(token); when(this.requestMatcher.matches(this.request)).thenReturn(true);
request.setParameter(token.getParameterName(), token.getToken()); when(this.tokenRepository.generateToken(this.request)).thenReturn(this.token);
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
filter.doFilter(request, response, filterChain); this.filter.doFilter(this.request, this.response, this.filterChain);
assertToken(request.getAttribute(token.getParameterName())).isEqualTo(token); assertToken(this.request.getAttribute(this.token.getParameterName()))
assertToken(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); .isEqualTo(this.token);
assertToken(this.request.getAttribute(CsrfToken.class.getName()))
.isEqualTo(this.token);
verify(filterChain).doFilter(request, response); // LazyCsrfTokenRepository requires the response as an attribute
verifyZeroInteractions(deniedHandler); assertThat(this.request.getAttribute(HttpServletResponse.class.getName()))
.isEqualTo(this.response);
verify(this.filterChain).doFilter(this.request, this.response);
verify(this.tokenRepository).saveToken(this.token, this.request, this.response);
verifyZeroInteractions(this.deniedHandler);
} }
@Test @Test
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods() public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods()
throws ServletException, IOException { throws ServletException, IOException {
filter = new CsrfFilter(tokenRepository); this.filter = new CsrfFilter(this.tokenRepository);
filter.setAccessDeniedHandler(deniedHandler); this.filter.setAccessDeniedHandler(this.deniedHandler);
for (String method : Arrays.asList("GET", "TRACE", "OPTIONS", "HEAD")) { for (String method : Arrays.asList("GET", "TRACE", "OPTIONS", "HEAD")) {
resetRequestResponse(); resetRequestResponse();
when(tokenRepository.loadToken(request)).thenReturn(token); when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token);
request.setMethod(method); this.request.setMethod(method);
filter.doFilter(request, response, filterChain); this.filter.doFilter(this.request, this.response, this.filterChain);
verify(filterChain).doFilter(request, response); verify(this.filterChain).doFilter(this.request, this.response);
verifyZeroInteractions(deniedHandler); verifyZeroInteractions(this.deniedHandler);
} }
} }
/** /**
* SEC-2292 Should not allow other cases through since spec states HTTP method is case * SEC-2292 Should not allow other cases through since spec states HTTP method is case
* sensitive http://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html#sec5.1.1 * sensitive http://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html#sec5.1.1
* @throws Exception if an error occurs
* *
* @throws ServletException
* @throws IOException
*/ */
@Test @Test
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethodsCaseSensitive() public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethodsCaseSensitive()
throws ServletException, IOException { throws Exception {
filter = new CsrfFilter(tokenRepository); this.filter = new CsrfFilter(this.tokenRepository);
filter.setAccessDeniedHandler(deniedHandler); this.filter.setAccessDeniedHandler(this.deniedHandler);
for (String method : Arrays.asList("get", "TrAcE", "oPTIOnS", "hEaD")) { for (String method : Arrays.asList("get", "TrAcE", "oPTIOnS", "hEaD")) {
resetRequestResponse(); resetRequestResponse();
when(tokenRepository.loadToken(request)).thenReturn(token); when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token);
request.setMethod(method); this.request.setMethod(method);
filter.doFilter(request, response, filterChain); this.filter.doFilter(this.request, this.response, this.filterChain);
verify(deniedHandler).handle(eq(request), eq(response), verify(this.deniedHandler).handle(eq(this.request), eq(this.response),
any(InvalidCsrfTokenException.class)); any(InvalidCsrfTokenException.class));
verifyZeroInteractions(filterChain); verifyZeroInteractions(this.filterChain);
} }
} }
@Test @Test
public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods() public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods()
throws ServletException, IOException { throws ServletException, IOException {
filter = new CsrfFilter(tokenRepository); this.filter = new CsrfFilter(this.tokenRepository);
filter.setAccessDeniedHandler(deniedHandler); this.filter.setAccessDeniedHandler(this.deniedHandler);
for (String method : Arrays.asList("POST", "PUT", "PATCH", "DELETE", "INVALID")) { for (String method : Arrays.asList("POST", "PUT", "PATCH", "DELETE", "INVALID")) {
resetRequestResponse(); resetRequestResponse();
when(tokenRepository.loadToken(request)).thenReturn(token); when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token);
request.setMethod(method); this.request.setMethod(method);
filter.doFilter(request, response, filterChain); this.filter.doFilter(this.request, this.response, this.filterChain);
verify(deniedHandler).handle(eq(request), eq(response), verify(this.deniedHandler).handle(eq(this.request), eq(this.response),
any(InvalidCsrfTokenException.class)); any(InvalidCsrfTokenException.class));
verifyZeroInteractions(filterChain); verifyZeroInteractions(this.filterChain);
} }
} }
@Test @Test
public void doFilterDefaultAccessDenied() throws ServletException, IOException { public void doFilterDefaultAccessDenied() throws ServletException, IOException {
filter = new CsrfFilter(tokenRepository); this.filter = new CsrfFilter(this.tokenRepository);
filter.setRequireCsrfProtectionMatcher(requestMatcher); this.filter.setRequireCsrfProtectionMatcher(this.requestMatcher);
when(requestMatcher.matches(request)).thenReturn(true); when(this.requestMatcher.matches(this.request)).thenReturn(true);
when(tokenRepository.loadToken(request)).thenReturn(token); when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token);
filter.doFilter(request, response, filterChain); this.filter.doFilter(this.request, this.response, this.filterChain);
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); assertThat(this.request.getAttribute(this.token.getParameterName()))
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); .isEqualTo(this.token);
assertThat(this.request.getAttribute(CsrfToken.class.getName()))
.isEqualTo(this.token);
assertThat(response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN);
verifyZeroInteractions(filterChain); verifyZeroInteractions(this.filterChain);
} }
@Test(expected = IllegalArgumentException.class) @Test(expected = IllegalArgumentException.class)
public void setRequireCsrfProtectionMatcherNull() { public void setRequireCsrfProtectionMatcherNull() {
filter.setRequireCsrfProtectionMatcher(null); this.filter.setRequireCsrfProtectionMatcher(null);
} }
@Test(expected = IllegalArgumentException.class) @Test(expected = IllegalArgumentException.class)
public void setAccessDeniedHandlerNull() { public void setAccessDeniedHandlerNull() {
filter.setAccessDeniedHandler(null); this.filter.setAccessDeniedHandler(null);
} }
private static final CsrfTokenAssert assertToken(Object token) { private static final CsrfTokenAssert assertToken(Object token) {
return new CsrfTokenAssert((CsrfToken) token); return new CsrfTokenAssert((CsrfToken) token);
} }
private static class CsrfTokenAssert extends private static class CsrfTokenAssert
AbstractObjectAssert<CsrfTokenAssert, CsrfToken> { extends AbstractObjectAssert<CsrfTokenAssert, CsrfToken> {
/** /**
* Creates a new </code>{@link ObjectAssert}</code>. * Creates a new </code>{@link ObjectAssert}</code>.
@ -369,13 +413,14 @@ public class CsrfFilterTests {
* @param actual the target to verify. * @param actual the target to verify.
*/ */
protected CsrfTokenAssert(CsrfToken actual) { protected CsrfTokenAssert(CsrfToken actual) {
super(actual,CsrfTokenAssert.class); super(actual, CsrfTokenAssert.class);
} }
public CsrfTokenAssert isEqualTo(CsrfToken expected) { public CsrfTokenAssert isEqualTo(CsrfToken expected) {
assertThat(actual.getHeaderName()).isEqualTo(expected.getHeaderName()); assertThat(this.actual.getHeaderName()).isEqualTo(expected.getHeaderName());
assertThat(actual.getParameterName()).isEqualTo(expected.getParameterName()); assertThat(this.actual.getParameterName())
assertThat(actual.getToken()).isEqualTo(expected.getToken()); .isEqualTo(expected.getParameterName());
assertThat(this.actual.getToken()).isEqualTo(expected.getToken());
return this; return this;
} }
} }

View File

@ -0,0 +1,102 @@
/*
* Copyright 2012-2016 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.web.csrf;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
/**
* @author Rob Winch
*/
@RunWith(MockitoJUnitRunner.class)
public class LazyCsrfTokenRepositoryTests {
@Mock
CsrfTokenRepository delegate;
@Mock
HttpServletRequest request;
@Mock
HttpServletResponse response;
@InjectMocks
LazyCsrfTokenRepository repository;
DefaultCsrfToken token;
@Before
public void setup() {
this.token = new DefaultCsrfToken("header", "param", "token");
when(this.delegate.generateToken(this.request)).thenReturn(this.token);
when(this.request.getAttribute(HttpServletResponse.class.getName()))
.thenReturn(this.response);
}
@Test(expected = IllegalArgumentException.class)
public void constructNullDelegateThrowsIllegalArgumentException() {
new LazyCsrfTokenRepository(null);
}
@Test(expected = IllegalArgumentException.class)
public void generateTokenNullResponseAttribute() {
this.repository.generateToken(mock(HttpServletRequest.class));
}
@Test
public void generateTokenGetTokenSavesToken() {
CsrfToken newToken = this.repository.generateToken(this.request);
newToken.getToken();
verify(this.delegate).saveToken(this.token, this.request, this.response);
}
@Test
public void saveNonNullDoesNothing() {
this.repository.saveToken(this.token, this.request, this.response);
verifyZeroInteractions(this.delegate);
}
@Test
public void saveNullDelegates() {
this.repository.saveToken(null, this.request, this.response);
verify(this.delegate).saveToken(null, this.request, this.response);
}
@Test
public void loadTokenDelegates() {
when(this.delegate.loadToken(this.request)).thenReturn(this.token);
CsrfToken loadToken = this.repository.loadToken(this.request);
assertThat(loadToken).isSameAs(this.token);
verify(this.delegate).loadToken(this.request);
}
}