SEC-2276: Delay saving CsrfToken until token is accessed
This also removed the CsrfToken from the response headers to prevent the token from being saved. If user's wish to return the CsrfToken in the response headers, they should use the CsrfToken found on the request.
This commit is contained in:
parent
c131fb6379
commit
48283ec004
|
@ -36,6 +36,7 @@ import org.springframework.security.web.access.intercept.FilterSecurityIntercept
|
||||||
import org.springframework.security.web.context.HttpRequestResponseHolder
|
import org.springframework.security.web.context.HttpRequestResponseHolder
|
||||||
import org.springframework.security.web.context.HttpSessionSecurityContextRepository
|
import org.springframework.security.web.context.HttpSessionSecurityContextRepository
|
||||||
import org.springframework.security.web.csrf.CsrfToken
|
import org.springframework.security.web.csrf.CsrfToken
|
||||||
|
import org.springframework.security.web.csrf.DefaultCsrfToken;
|
||||||
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository
|
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository
|
||||||
|
|
||||||
import spock.lang.AutoCleanup
|
import spock.lang.AutoCleanup
|
||||||
|
@ -69,7 +70,7 @@ abstract class BaseSpringSpec extends Specification {
|
||||||
}
|
}
|
||||||
|
|
||||||
def setupCsrf(csrfTokenValue="BaseSpringSpec_CSRFTOKEN") {
|
def setupCsrf(csrfTokenValue="BaseSpringSpec_CSRFTOKEN") {
|
||||||
csrfToken = new CsrfToken("X-CSRF-TOKEN","_csrf",csrfTokenValue)
|
csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf",csrfTokenValue)
|
||||||
new HttpSessionCsrfTokenRepository().saveToken(csrfToken, request,response)
|
new HttpSessionCsrfTokenRepository().saveToken(csrfToken, request,response)
|
||||||
request.setParameter(csrfToken.parameterName, csrfToken.token)
|
request.setParameter(csrfToken.parameterName, csrfToken.token)
|
||||||
}
|
}
|
||||||
|
|
|
@ -79,8 +79,7 @@ class WebSecurityConfigurerAdapterTests extends BaseSpringSpec {
|
||||||
'Strict-Transport-Security': 'max-age=31536000 ; includeSubDomains',
|
'Strict-Transport-Security': 'max-age=31536000 ; includeSubDomains',
|
||||||
'Cache-Control': 'no-cache,no-store,max-age=0,must-revalidate',
|
'Cache-Control': 'no-cache,no-store,max-age=0,must-revalidate',
|
||||||
'Pragma':'no-cache',
|
'Pragma':'no-cache',
|
||||||
'X-XSS-Protection' : '1; mode=block',
|
'X-XSS-Protection' : '1; mode=block']
|
||||||
'X-CSRF-TOKEN' : csrfToken.token]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@EnableWebSecurity
|
@EnableWebSecurity
|
||||||
|
|
|
@ -49,8 +49,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec {
|
||||||
'Strict-Transport-Security': 'max-age=31536000 ; includeSubDomains',
|
'Strict-Transport-Security': 'max-age=31536000 ; includeSubDomains',
|
||||||
'Cache-Control': 'no-cache,no-store,max-age=0,must-revalidate',
|
'Cache-Control': 'no-cache,no-store,max-age=0,must-revalidate',
|
||||||
'Pragma':'no-cache',
|
'Pragma':'no-cache',
|
||||||
'X-XSS-Protection' : '1; mode=block',
|
'X-XSS-Protection' : '1; mode=block']
|
||||||
'X-CSRF-TOKEN' : csrfToken.token]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Configuration
|
@Configuration
|
||||||
|
@ -70,8 +69,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec {
|
||||||
springSecurityFilterChain.doFilter(request,response,chain)
|
springSecurityFilterChain.doFilter(request,response,chain)
|
||||||
then:
|
then:
|
||||||
responseHeaders == ['Cache-Control': 'no-cache,no-store,max-age=0,must-revalidate',
|
responseHeaders == ['Cache-Control': 'no-cache,no-store,max-age=0,must-revalidate',
|
||||||
'Pragma':'no-cache',
|
'Pragma':'no-cache']
|
||||||
'X-CSRF-TOKEN' : csrfToken.token]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Configuration
|
@Configuration
|
||||||
|
@ -91,8 +89,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec {
|
||||||
when:
|
when:
|
||||||
springSecurityFilterChain.doFilter(request,response,chain)
|
springSecurityFilterChain.doFilter(request,response,chain)
|
||||||
then:
|
then:
|
||||||
responseHeaders == ['Strict-Transport-Security': 'max-age=31536000 ; includeSubDomains',
|
responseHeaders == ['Strict-Transport-Security': 'max-age=31536000 ; includeSubDomains']
|
||||||
'X-CSRF-TOKEN' : csrfToken.token]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Configuration
|
@Configuration
|
||||||
|
@ -111,8 +108,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec {
|
||||||
when:
|
when:
|
||||||
springSecurityFilterChain.doFilter(request,response,chain)
|
springSecurityFilterChain.doFilter(request,response,chain)
|
||||||
then:
|
then:
|
||||||
responseHeaders == ['Strict-Transport-Security': 'max-age=15768000',
|
responseHeaders == ['Strict-Transport-Security': 'max-age=15768000']
|
||||||
'X-CSRF-TOKEN' : csrfToken.token]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Configuration
|
@Configuration
|
||||||
|
@ -133,8 +129,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec {
|
||||||
when:
|
when:
|
||||||
springSecurityFilterChain.doFilter(request,response,chain)
|
springSecurityFilterChain.doFilter(request,response,chain)
|
||||||
then:
|
then:
|
||||||
responseHeaders == ['X-Frame-Options': 'SAMEORIGIN',
|
responseHeaders == ['X-Frame-Options': 'SAMEORIGIN']
|
||||||
'X-CSRF-TOKEN' : csrfToken.token]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Configuration
|
@Configuration
|
||||||
|
@ -156,8 +151,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec {
|
||||||
when:
|
when:
|
||||||
springSecurityFilterChain.doFilter(request,response,chain)
|
springSecurityFilterChain.doFilter(request,response,chain)
|
||||||
then:
|
then:
|
||||||
responseHeaders == ['X-Frame-Options': 'ALLOW-FROM https://example.com',
|
responseHeaders == ['X-Frame-Options': 'ALLOW-FROM https://example.com']
|
||||||
'X-CSRF-TOKEN' : csrfToken.token]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -178,8 +172,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec {
|
||||||
when:
|
when:
|
||||||
springSecurityFilterChain.doFilter(request,response,chain)
|
springSecurityFilterChain.doFilter(request,response,chain)
|
||||||
then:
|
then:
|
||||||
responseHeaders == ['X-XSS-Protection': '1; mode=block',
|
responseHeaders == ['X-XSS-Protection': '1; mode=block']
|
||||||
'X-CSRF-TOKEN' : csrfToken.token]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Configuration
|
@Configuration
|
||||||
|
@ -199,8 +192,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec {
|
||||||
when:
|
when:
|
||||||
springSecurityFilterChain.doFilter(request,response,chain)
|
springSecurityFilterChain.doFilter(request,response,chain)
|
||||||
then:
|
then:
|
||||||
responseHeaders == ['X-XSS-Protection': '1',
|
responseHeaders == ['X-XSS-Protection': '1']
|
||||||
'X-CSRF-TOKEN' : csrfToken.token]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Configuration
|
@Configuration
|
||||||
|
@ -220,8 +212,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec {
|
||||||
when:
|
when:
|
||||||
springSecurityFilterChain.doFilter(request,response,chain)
|
springSecurityFilterChain.doFilter(request,response,chain)
|
||||||
then:
|
then:
|
||||||
responseHeaders == ['X-Content-Type-Options': 'nosniff',
|
responseHeaders == ['X-Content-Type-Options': 'nosniff']
|
||||||
'X-CSRF-TOKEN' : csrfToken.token]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Configuration
|
@Configuration
|
||||||
|
@ -243,8 +234,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec {
|
||||||
when:
|
when:
|
||||||
springSecurityFilterChain.doFilter(request,response,chain)
|
springSecurityFilterChain.doFilter(request,response,chain)
|
||||||
then:
|
then:
|
||||||
responseHeaders == ['customHeaderName': 'customHeaderValue',
|
responseHeaders == ['customHeaderName': 'customHeaderValue']
|
||||||
'X-CSRF-TOKEN' : csrfToken.token]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Configuration
|
@Configuration
|
||||||
|
|
|
@ -29,6 +29,7 @@ import org.springframework.security.web.access.AccessDeniedHandler;
|
||||||
import org.springframework.security.web.csrf.CsrfFilter
|
import org.springframework.security.web.csrf.CsrfFilter
|
||||||
import org.springframework.security.web.csrf.CsrfToken;
|
import org.springframework.security.web.csrf.CsrfToken;
|
||||||
import org.springframework.security.web.csrf.CsrfTokenRepository;
|
import org.springframework.security.web.csrf.CsrfTokenRepository;
|
||||||
|
import org.springframework.security.web.csrf.DefaultCsrfToken;
|
||||||
import org.springframework.security.web.servlet.support.csrf.CsrfRequestDataValueProcessor
|
import org.springframework.security.web.servlet.support.csrf.CsrfRequestDataValueProcessor
|
||||||
import org.springframework.security.web.util.RequestMatcher
|
import org.springframework.security.web.util.RequestMatcher
|
||||||
|
|
||||||
|
@ -113,7 +114,7 @@ class CsrfConfigTests extends AbstractHttpConfigTests {
|
||||||
mockBean(CsrfTokenRepository,'repo')
|
mockBean(CsrfTokenRepository,'repo')
|
||||||
createAppContext()
|
createAppContext()
|
||||||
CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository)
|
CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository)
|
||||||
CsrfToken token = new CsrfToken("X-CSRF-TOKEN","_csrf", "abc")
|
CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc")
|
||||||
when(repo.loadToken(any(HttpServletRequest))).thenReturn(token)
|
when(repo.loadToken(any(HttpServletRequest))).thenReturn(token)
|
||||||
request.setParameter(token.parameterName,token.token)
|
request.setParameter(token.parameterName,token.token)
|
||||||
request.servletPath = "/some-url"
|
request.servletPath = "/some-url"
|
||||||
|
@ -147,7 +148,7 @@ class CsrfConfigTests extends AbstractHttpConfigTests {
|
||||||
mockBean(CsrfTokenRepository,'repo')
|
mockBean(CsrfTokenRepository,'repo')
|
||||||
createAppContext()
|
createAppContext()
|
||||||
CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository)
|
CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository)
|
||||||
CsrfToken token = new CsrfToken("X-CSRF-TOKEN","_csrf", "abc")
|
CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc")
|
||||||
when(repo.loadToken(any(HttpServletRequest))).thenReturn(token)
|
when(repo.loadToken(any(HttpServletRequest))).thenReturn(token)
|
||||||
request.setParameter(token.parameterName,token.token)
|
request.setParameter(token.parameterName,token.token)
|
||||||
request.servletPath = "/some-url"
|
request.servletPath = "/some-url"
|
||||||
|
@ -200,7 +201,7 @@ class CsrfConfigTests extends AbstractHttpConfigTests {
|
||||||
mockBean(CsrfTokenRepository,'repo')
|
mockBean(CsrfTokenRepository,'repo')
|
||||||
createAppContext()
|
createAppContext()
|
||||||
CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository)
|
CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository)
|
||||||
CsrfToken token = new CsrfToken("X-CSRF-TOKEN","_csrf", "abc")
|
CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc")
|
||||||
when(repo.loadToken(any(HttpServletRequest))).thenReturn(token)
|
when(repo.loadToken(any(HttpServletRequest))).thenReturn(token)
|
||||||
request.setParameter(token.parameterName,token.token)
|
request.setParameter(token.parameterName,token.token)
|
||||||
request.method = "POST"
|
request.method = "POST"
|
||||||
|
@ -223,7 +224,7 @@ class CsrfConfigTests extends AbstractHttpConfigTests {
|
||||||
mockBean(CsrfTokenRepository,'repo')
|
mockBean(CsrfTokenRepository,'repo')
|
||||||
createAppContext()
|
createAppContext()
|
||||||
CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository)
|
CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository)
|
||||||
CsrfToken token = new CsrfToken("X-CSRF-TOKEN","_csrf", "abc")
|
CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc")
|
||||||
when(repo.loadToken(any(HttpServletRequest))).thenReturn(token)
|
when(repo.loadToken(any(HttpServletRequest))).thenReturn(token)
|
||||||
request.setParameter(token.parameterName,token.token)
|
request.setParameter(token.parameterName,token.token)
|
||||||
request.method = "POST"
|
request.method = "POST"
|
||||||
|
@ -244,7 +245,7 @@ class CsrfConfigTests extends AbstractHttpConfigTests {
|
||||||
mockBean(CsrfTokenRepository,'repo')
|
mockBean(CsrfTokenRepository,'repo')
|
||||||
createAppContext()
|
createAppContext()
|
||||||
CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository)
|
CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository)
|
||||||
CsrfToken token = new CsrfToken("X-CSRF-TOKEN","_csrf", "abc")
|
CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc")
|
||||||
when(repo.loadToken(any(HttpServletRequest))).thenReturn(token)
|
when(repo.loadToken(any(HttpServletRequest))).thenReturn(token)
|
||||||
request.setParameter(token.parameterName,token.token)
|
request.setParameter(token.parameterName,token.token)
|
||||||
request.method = "POST"
|
request.method = "POST"
|
||||||
|
|
|
@ -40,7 +40,6 @@ import org.springframework.context.annotation.Configuration;
|
||||||
import org.springframework.mock.web.MockFilterChain;
|
import org.springframework.mock.web.MockFilterChain;
|
||||||
import org.springframework.mock.web.MockHttpServletRequest;
|
import org.springframework.mock.web.MockHttpServletRequest;
|
||||||
import org.springframework.mock.web.MockHttpServletResponse;
|
import org.springframework.mock.web.MockHttpServletResponse;
|
||||||
import org.springframework.security.authentication.TestingAuthenticationToken;
|
|
||||||
import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder;
|
import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder;
|
||||||
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
|
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
|
||||||
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
|
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
|
||||||
|
@ -96,7 +95,9 @@ public class SessionManagementConfigurerServlet31Tests {
|
||||||
request.setMethod("POST");
|
request.setMethod("POST");
|
||||||
request.setParameter("username", "user");
|
request.setParameter("username", "user");
|
||||||
request.setParameter("password", "password");
|
request.setParameter("password", "password");
|
||||||
CsrfToken token = new HttpSessionCsrfTokenRepository().generateAndSaveToken(request, response);
|
HttpSessionCsrfTokenRepository repository = new HttpSessionCsrfTokenRepository();
|
||||||
|
CsrfToken token = repository.generateToken(request);
|
||||||
|
repository.saveToken(token, request, response);
|
||||||
request.setParameter(token.getParameterName(),token.getToken());
|
request.setParameter(token.getParameterName(),token.getToken());
|
||||||
when(ReflectionUtils.findMethod(HttpServletRequest.class, "changeSessionId")).thenReturn(method);
|
when(ReflectionUtils.findMethod(HttpServletRequest.class, "changeSessionId")).thenReturn(method);
|
||||||
|
|
||||||
|
|
|
@ -70,11 +70,11 @@ public final class CsrfFilter extends OncePerRequestFilter {
|
||||||
throws ServletException, IOException {
|
throws ServletException, IOException {
|
||||||
CsrfToken csrfToken = tokenRepository.loadToken(request);
|
CsrfToken csrfToken = tokenRepository.loadToken(request);
|
||||||
if(csrfToken == null) {
|
if(csrfToken == null) {
|
||||||
csrfToken = tokenRepository.generateAndSaveToken(request, response);
|
CsrfToken generatedToken = tokenRepository.generateToken(request);
|
||||||
|
csrfToken = new SaveOnAccessCsrfToken(tokenRepository, request, response, generatedToken);
|
||||||
}
|
}
|
||||||
request.setAttribute(CsrfToken.class.getName(), csrfToken);
|
request.setAttribute(CsrfToken.class.getName(), csrfToken);
|
||||||
request.setAttribute(csrfToken.getParameterName(), csrfToken);
|
request.setAttribute(csrfToken.getParameterName(), csrfToken);
|
||||||
response.addHeader(csrfToken.getHeaderName(), csrfToken.getToken());
|
|
||||||
|
|
||||||
if(!requireCsrfProtectionMatcher.matches(request)) {
|
if(!requireCsrfProtectionMatcher.matches(request)) {
|
||||||
filterChain.doFilter(request, response);
|
filterChain.doFilter(request, response);
|
||||||
|
@ -128,7 +128,86 @@ public final class CsrfFilter extends OncePerRequestFilter {
|
||||||
this.accessDeniedHandler = accessDeniedHandler;
|
this.accessDeniedHandler = accessDeniedHandler;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static class DefaultRequiresCsrfMatcher implements RequestMatcher {
|
@SuppressWarnings("serial")
|
||||||
|
private static final class SaveOnAccessCsrfToken implements CsrfToken {
|
||||||
|
private transient CsrfTokenRepository tokenRepository;
|
||||||
|
private transient HttpServletRequest request;
|
||||||
|
private transient HttpServletResponse response;
|
||||||
|
|
||||||
|
private final CsrfToken delegate;
|
||||||
|
|
||||||
|
public SaveOnAccessCsrfToken(CsrfTokenRepository tokenRepository,
|
||||||
|
HttpServletRequest request, HttpServletResponse response,
|
||||||
|
CsrfToken delegate) {
|
||||||
|
super();
|
||||||
|
this.tokenRepository = tokenRepository;
|
||||||
|
this.request = request;
|
||||||
|
this.response = response;
|
||||||
|
this.delegate = delegate;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getHeaderName() {
|
||||||
|
return delegate.getHeaderName();
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getParameterName() {
|
||||||
|
return delegate.getParameterName();
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getToken() {
|
||||||
|
saveTokenIfNecessary();
|
||||||
|
return delegate.getToken();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "SaveOnAccessCsrfToken [delegate=" + delegate + "]";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
final int prime = 31;
|
||||||
|
int result = 1;
|
||||||
|
result = prime * result
|
||||||
|
+ ((delegate == null) ? 0 : delegate.hashCode());
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object obj) {
|
||||||
|
if (this == obj)
|
||||||
|
return true;
|
||||||
|
if (obj == null)
|
||||||
|
return false;
|
||||||
|
if (getClass() != obj.getClass())
|
||||||
|
return false;
|
||||||
|
SaveOnAccessCsrfToken other = (SaveOnAccessCsrfToken) obj;
|
||||||
|
if (delegate == null) {
|
||||||
|
if (other.delegate != null)
|
||||||
|
return false;
|
||||||
|
} else if (!delegate.equals(other.delegate))
|
||||||
|
return false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void saveTokenIfNecessary() {
|
||||||
|
if(this.tokenRepository == null) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
synchronized(this) {
|
||||||
|
if(tokenRepository != null) {
|
||||||
|
this.tokenRepository.saveToken(delegate, request, response);
|
||||||
|
this.tokenRepository = null;
|
||||||
|
this.request = null;
|
||||||
|
this.response = null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private static final class DefaultRequiresCsrfMatcher implements RequestMatcher {
|
||||||
private Pattern allowedMethods = Pattern.compile("^(GET|HEAD|TRACE|OPTIONS)$");
|
private Pattern allowedMethods = Pattern.compile("^(GET|HEAD|TRACE|OPTIONS)$");
|
||||||
|
|
||||||
/* (non-Javadoc)
|
/* (non-Javadoc)
|
||||||
|
|
|
@ -17,37 +17,16 @@ package org.springframework.security.web.csrf;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
|
||||||
import org.springframework.util.Assert;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A CSRF token that is used to protect against CSRF attacks.
|
* Provides the information about an expected CSRF token.
|
||||||
|
*
|
||||||
|
* @see DefaultCsrfToken
|
||||||
*
|
*
|
||||||
* @author Rob Winch
|
* @author Rob Winch
|
||||||
* @since 3.2
|
* @since 3.2
|
||||||
|
*
|
||||||
*/
|
*/
|
||||||
@SuppressWarnings("serial")
|
public interface CsrfToken extends Serializable {
|
||||||
public final class CsrfToken implements Serializable {
|
|
||||||
|
|
||||||
private final String token;
|
|
||||||
|
|
||||||
private final String parameterName;
|
|
||||||
|
|
||||||
private final String headerName;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates a new instance
|
|
||||||
* @param headerName the HTTP header name to use
|
|
||||||
* @param parameterName the HTTP parameter name to use
|
|
||||||
* @param token the value of the token (i.e. expected value of the HTTP parameter of parametername).
|
|
||||||
*/
|
|
||||||
public CsrfToken(String headerName, String parameterName, String token) {
|
|
||||||
Assert.hasLength(headerName, "headerName cannot be null or empty");
|
|
||||||
Assert.hasLength(parameterName, "parameterName cannot be null or empty");
|
|
||||||
Assert.hasLength(token, "token cannot be null or empty");
|
|
||||||
this.headerName = headerName;
|
|
||||||
this.parameterName = parameterName;
|
|
||||||
this.token = token;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Gets the HTTP header that the CSRF is populated on the response and can
|
* Gets the HTTP header that the CSRF is populated on the response and can
|
||||||
|
@ -56,23 +35,18 @@ public final class CsrfToken implements Serializable {
|
||||||
* @return the HTTP header that the CSRF is populated on the response and
|
* @return the HTTP header that the CSRF is populated on the response and
|
||||||
* can be placed on requests instead of the parameter
|
* can be placed on requests instead of the parameter
|
||||||
*/
|
*/
|
||||||
public String getHeaderName() {
|
String getHeaderName();
|
||||||
return headerName;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Gets the HTTP parameter name that should contain the token. Cannot be null.
|
* Gets the HTTP parameter name that should contain the token. Cannot be null.
|
||||||
* @return the HTTP parameter name that should contain the token.
|
* @return the HTTP parameter name that should contain the token.
|
||||||
*/
|
*/
|
||||||
public String getParameterName() {
|
String getParameterName();
|
||||||
return parameterName;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Gets the token value. Cannot be null.
|
* Gets the token value. Cannot be null.
|
||||||
* @return the token value
|
* @return the token value
|
||||||
*/
|
*/
|
||||||
public String getToken() {
|
String getToken();
|
||||||
return token;
|
|
||||||
}
|
|
||||||
}
|
}
|
|
@ -33,17 +33,14 @@ import javax.servlet.http.HttpSession;
|
||||||
public interface CsrfTokenRepository {
|
public interface CsrfTokenRepository {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generates and saves the expected {@link CsrfToken}
|
* Generates a {@link CsrfToken}
|
||||||
*
|
*
|
||||||
* @param request
|
* @param request
|
||||||
* the {@link HttpServletRequest} to use
|
* the {@link HttpServletRequest} to use
|
||||||
* @param response
|
* @return the {@link CsrfToken} that was generated. Cannot be
|
||||||
* the {@link HttpServletResponse} to use
|
|
||||||
* @return the {@link CsrfToken} that was generated and saved. Cannot be
|
|
||||||
* null.
|
* null.
|
||||||
*/
|
*/
|
||||||
CsrfToken generateAndSaveToken(HttpServletRequest request,
|
CsrfToken generateToken(HttpServletRequest request);
|
||||||
HttpServletResponse response);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Saves the {@link CsrfToken} using the {@link HttpServletRequest} and
|
* Saves the {@link CsrfToken} using the {@link HttpServletRequest} and
|
||||||
|
|
|
@ -0,0 +1,70 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2002-2013 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.springframework.security.web.csrf;
|
||||||
|
|
||||||
|
import org.springframework.util.Assert;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A CSRF token that is used to protect against CSRF attacks.
|
||||||
|
*
|
||||||
|
* @author Rob Winch
|
||||||
|
* @since 3.2
|
||||||
|
*/
|
||||||
|
@SuppressWarnings("serial")
|
||||||
|
public final class DefaultCsrfToken implements CsrfToken {
|
||||||
|
|
||||||
|
private final String token;
|
||||||
|
|
||||||
|
private final String parameterName;
|
||||||
|
|
||||||
|
private final String headerName;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a new instance
|
||||||
|
* @param headerName the HTTP header name to use
|
||||||
|
* @param parameterName the HTTP parameter name to use
|
||||||
|
* @param token the value of the token (i.e. expected value of the HTTP parameter of parametername).
|
||||||
|
*/
|
||||||
|
public DefaultCsrfToken(String headerName, String parameterName, String token) {
|
||||||
|
Assert.hasLength(headerName, "headerName cannot be null or empty");
|
||||||
|
Assert.hasLength(parameterName, "parameterName cannot be null or empty");
|
||||||
|
Assert.hasLength(token, "token cannot be null or empty");
|
||||||
|
this.headerName = headerName;
|
||||||
|
this.parameterName = parameterName;
|
||||||
|
this.token = token;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (non-Javadoc)
|
||||||
|
* @see org.springframework.security.web.csrf.CsrfToken#getHeaderName()
|
||||||
|
*/
|
||||||
|
public String getHeaderName() {
|
||||||
|
return headerName;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (non-Javadoc)
|
||||||
|
* @see org.springframework.security.web.csrf.CsrfToken#getParameterName()
|
||||||
|
*/
|
||||||
|
public String getParameterName() {
|
||||||
|
return parameterName;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (non-Javadoc)
|
||||||
|
* @see org.springframework.security.web.csrf.CsrfToken#getToken()
|
||||||
|
*/
|
||||||
|
public String getToken() {
|
||||||
|
return token;
|
||||||
|
}
|
||||||
|
}
|
|
@ -63,14 +63,12 @@ public final class HttpSessionCsrfTokenRepository implements CsrfTokenRepository
|
||||||
return (CsrfToken) request.getSession().getAttribute(sessionAttributeName);
|
return (CsrfToken) request.getSession().getAttribute(sessionAttributeName);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (non-Javadoc)
|
/*
|
||||||
* @see org.springframework.security.web.csrf.CsrfTokenRepository#generateNewToken(javax.servlet.http.HttpServletRequest, javax.servlet.http.HttpServletResponse)
|
* (non-Javadoc)
|
||||||
|
* @see org.springframework.security.web.csrf.CsrfTokenRepository#generateToken(javax.servlet.http.HttpServletRequest)
|
||||||
*/
|
*/
|
||||||
public CsrfToken generateAndSaveToken(HttpServletRequest request,
|
public CsrfToken generateToken(HttpServletRequest request) {
|
||||||
HttpServletResponse response) {
|
return new DefaultCsrfToken(headerName, parameterName, createNewToken());
|
||||||
CsrfToken token = new CsrfToken(headerName, parameterName, createNewToken());
|
|
||||||
saveToken(token, request, response);
|
|
||||||
return token;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.springframework.security.web.csrf;
|
||||||
import static org.fest.assertions.Assertions.assertThat;
|
import static org.fest.assertions.Assertions.assertThat;
|
||||||
import static org.mockito.Matchers.any;
|
import static org.mockito.Matchers.any;
|
||||||
import static org.mockito.Matchers.eq;
|
import static org.mockito.Matchers.eq;
|
||||||
|
import static org.mockito.Mockito.times;
|
||||||
import static org.mockito.Mockito.verify;
|
import static org.mockito.Mockito.verify;
|
||||||
import static org.mockito.Mockito.verifyZeroInteractions;
|
import static org.mockito.Mockito.verifyZeroInteractions;
|
||||||
import static org.mockito.Mockito.when;
|
import static org.mockito.Mockito.when;
|
||||||
|
@ -27,8 +28,11 @@ import java.util.Arrays;
|
||||||
|
|
||||||
import javax.servlet.FilterChain;
|
import javax.servlet.FilterChain;
|
||||||
import javax.servlet.ServletException;
|
import javax.servlet.ServletException;
|
||||||
|
import javax.servlet.http.HttpServletRequest;
|
||||||
import javax.servlet.http.HttpServletResponse;
|
import javax.servlet.http.HttpServletResponse;
|
||||||
|
|
||||||
|
import org.fest.assertions.GenericAssert;
|
||||||
|
import org.fest.assertions.ObjectAssert;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.junit.runner.RunWith;
|
import org.junit.runner.RunWith;
|
||||||
|
@ -59,12 +63,12 @@ public class CsrfFilterTests {
|
||||||
private MockHttpServletResponse response;
|
private MockHttpServletResponse response;
|
||||||
private CsrfToken token;
|
private CsrfToken token;
|
||||||
|
|
||||||
|
|
||||||
private CsrfFilter filter;
|
private CsrfFilter filter;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setup() {
|
public void setup() {
|
||||||
token = new CsrfToken("headerName","paramName", "csrfTokenValue");
|
token = new DefaultCsrfToken("headerName", "paramName",
|
||||||
|
"csrfTokenValue");
|
||||||
resetRequestResponse();
|
resetRequestResponse();
|
||||||
filter = new CsrfFilter(tokenRepository);
|
filter = new CsrfFilter(tokenRepository);
|
||||||
filter.setRequireCsrfProtectionMatcher(requestMatcher);
|
filter.setRequireCsrfProtectionMatcher(requestMatcher);
|
||||||
|
@ -81,171 +85,221 @@ public class CsrfFilterTests {
|
||||||
new CsrfFilter(null);
|
new CsrfFilter(null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SEC-2276
|
||||||
@Test
|
@Test
|
||||||
public void doFilterAccessDeniedNoTokenPresent() throws ServletException, IOException {
|
public void doFilterDoesNotSaveCsrfTokenUntilAccessed() throws ServletException,
|
||||||
|
IOException {
|
||||||
|
when(requestMatcher.matches(request)).thenReturn(false);
|
||||||
|
when(tokenRepository.generateToken(request)).thenReturn(token);
|
||||||
|
|
||||||
|
filter.doFilter(request, response, filterChain);
|
||||||
|
CsrfToken attrToken = (CsrfToken) request.getAttribute(token.getParameterName());
|
||||||
|
|
||||||
|
// no CsrfToken should have been saved yet
|
||||||
|
verify(tokenRepository,times(0)).saveToken(any(CsrfToken.class), any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||||
|
verify(filterChain).doFilter(request, response);
|
||||||
|
|
||||||
|
// access the token
|
||||||
|
attrToken.getToken();
|
||||||
|
|
||||||
|
// now the CsrfToken should have been saved
|
||||||
|
verify(tokenRepository).saveToken(eq(token), any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void doFilterAccessDeniedNoTokenPresent() throws ServletException,
|
||||||
|
IOException {
|
||||||
when(requestMatcher.matches(request)).thenReturn(true);
|
when(requestMatcher.matches(request)).thenReturn(true);
|
||||||
when(tokenRepository.loadToken(request)).thenReturn(token);
|
when(tokenRepository.loadToken(request)).thenReturn(token);
|
||||||
|
|
||||||
filter.doFilter(request, response, filterChain);
|
filter.doFilter(request, response, filterChain);
|
||||||
|
|
||||||
assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
|
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(
|
||||||
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
|
token);
|
||||||
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
|
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
|
||||||
|
token);
|
||||||
|
|
||||||
verify(deniedHandler).handle(eq(request), eq(response), any(InvalidCsrfTokenException.class));
|
verify(deniedHandler).handle(eq(request), eq(response),
|
||||||
|
any(InvalidCsrfTokenException.class));
|
||||||
verifyZeroInteractions(filterChain);
|
verifyZeroInteractions(filterChain);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void doFilterAccessDeniedIncorrectTokenPresent() throws ServletException, IOException {
|
public void doFilterAccessDeniedIncorrectTokenPresent()
|
||||||
|
throws ServletException, IOException {
|
||||||
when(requestMatcher.matches(request)).thenReturn(true);
|
when(requestMatcher.matches(request)).thenReturn(true);
|
||||||
when(tokenRepository.loadToken(request)).thenReturn(token);
|
when(tokenRepository.loadToken(request)).thenReturn(token);
|
||||||
request.setParameter(token.getParameterName(), token.getToken()+ " INVALID");
|
request.setParameter(token.getParameterName(), token.getToken()
|
||||||
|
+ " INVALID");
|
||||||
|
|
||||||
filter.doFilter(request, response, filterChain);
|
filter.doFilter(request, response, filterChain);
|
||||||
|
|
||||||
assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
|
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(
|
||||||
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
|
token);
|
||||||
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
|
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
|
||||||
|
token);
|
||||||
|
|
||||||
verify(deniedHandler).handle(eq(request), eq(response), any(InvalidCsrfTokenException.class));
|
verify(deniedHandler).handle(eq(request), eq(response),
|
||||||
|
any(InvalidCsrfTokenException.class));
|
||||||
verifyZeroInteractions(filterChain);
|
verifyZeroInteractions(filterChain);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void doFilterAccessDeniedIncorrectTokenPresentHeader() throws ServletException, IOException {
|
public void doFilterAccessDeniedIncorrectTokenPresentHeader()
|
||||||
|
throws ServletException, IOException {
|
||||||
when(requestMatcher.matches(request)).thenReturn(true);
|
when(requestMatcher.matches(request)).thenReturn(true);
|
||||||
when(tokenRepository.loadToken(request)).thenReturn(token);
|
when(tokenRepository.loadToken(request)).thenReturn(token);
|
||||||
request.addHeader(token.getHeaderName(), token.getToken()+ " INVALID");
|
request.addHeader(token.getHeaderName(), token.getToken() + " INVALID");
|
||||||
|
|
||||||
filter.doFilter(request, response, filterChain);
|
filter.doFilter(request, response, filterChain);
|
||||||
|
|
||||||
assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
|
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(
|
||||||
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
|
token);
|
||||||
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
|
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
|
||||||
|
token);
|
||||||
|
|
||||||
verify(deniedHandler).handle(eq(request), eq(response), any(InvalidCsrfTokenException.class));
|
verify(deniedHandler).handle(eq(request), eq(response),
|
||||||
|
any(InvalidCsrfTokenException.class));
|
||||||
verifyZeroInteractions(filterChain);
|
verifyZeroInteractions(filterChain);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter() throws ServletException, IOException {
|
public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter()
|
||||||
|
throws ServletException, IOException {
|
||||||
when(requestMatcher.matches(request)).thenReturn(true);
|
when(requestMatcher.matches(request)).thenReturn(true);
|
||||||
when(tokenRepository.loadToken(request)).thenReturn(token);
|
when(tokenRepository.loadToken(request)).thenReturn(token);
|
||||||
request.setParameter(token.getParameterName(), token.getToken());
|
request.setParameter(token.getParameterName(), token.getToken());
|
||||||
request.addHeader(token.getHeaderName(), token.getToken()+ " INVALID");
|
request.addHeader(token.getHeaderName(), token.getToken() + " INVALID");
|
||||||
|
|
||||||
filter.doFilter(request, response, filterChain);
|
filter.doFilter(request, response, filterChain);
|
||||||
|
|
||||||
assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
|
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(
|
||||||
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
|
token);
|
||||||
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
|
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
|
||||||
|
token);
|
||||||
|
|
||||||
verify(deniedHandler).handle(eq(request), eq(response), any(InvalidCsrfTokenException.class));
|
verify(deniedHandler).handle(eq(request), eq(response),
|
||||||
|
any(InvalidCsrfTokenException.class));
|
||||||
verifyZeroInteractions(filterChain);
|
verifyZeroInteractions(filterChain);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void doFilterNotCsrfRequestExistingToken() throws ServletException, IOException {
|
public void doFilterNotCsrfRequestExistingToken() throws ServletException,
|
||||||
|
IOException {
|
||||||
when(requestMatcher.matches(request)).thenReturn(false);
|
when(requestMatcher.matches(request)).thenReturn(false);
|
||||||
when(tokenRepository.loadToken(request)).thenReturn(token);
|
when(tokenRepository.loadToken(request)).thenReturn(token);
|
||||||
|
|
||||||
filter.doFilter(request, response, filterChain);
|
filter.doFilter(request, response, filterChain);
|
||||||
|
|
||||||
assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
|
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(
|
||||||
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
|
token);
|
||||||
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
|
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
|
||||||
|
token);
|
||||||
|
|
||||||
verify(filterChain).doFilter(request, response);
|
verify(filterChain).doFilter(request, response);
|
||||||
verifyZeroInteractions(deniedHandler);
|
verifyZeroInteractions(deniedHandler);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void doFilterNotCsrfRequestGenerateToken() throws ServletException, IOException {
|
public void doFilterNotCsrfRequestGenerateToken() throws ServletException,
|
||||||
|
IOException {
|
||||||
when(requestMatcher.matches(request)).thenReturn(false);
|
when(requestMatcher.matches(request)).thenReturn(false);
|
||||||
when(tokenRepository.generateAndSaveToken(request, response)).thenReturn(token);
|
when(tokenRepository.generateToken(request))
|
||||||
|
.thenReturn(token);
|
||||||
|
|
||||||
filter.doFilter(request, response, filterChain);
|
filter.doFilter(request, response, filterChain);
|
||||||
|
|
||||||
assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
|
assertToken(request.getAttribute(token.getParameterName())).isEqualTo(
|
||||||
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
|
token);
|
||||||
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
|
assertToken(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
|
||||||
|
token);
|
||||||
|
|
||||||
verify(filterChain).doFilter(request, response);
|
verify(filterChain).doFilter(request, response);
|
||||||
verifyZeroInteractions(deniedHandler);
|
verifyZeroInteractions(deniedHandler);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void doFilterIsCsrfRequestExistingTokenHeader() throws ServletException, IOException {
|
public void doFilterIsCsrfRequestExistingTokenHeader()
|
||||||
|
throws ServletException, IOException {
|
||||||
when(requestMatcher.matches(request)).thenReturn(true);
|
when(requestMatcher.matches(request)).thenReturn(true);
|
||||||
when(tokenRepository.loadToken(request)).thenReturn(token);
|
when(tokenRepository.loadToken(request)).thenReturn(token);
|
||||||
request.addHeader(token.getHeaderName(), token.getToken());
|
request.addHeader(token.getHeaderName(), token.getToken());
|
||||||
|
|
||||||
filter.doFilter(request, response, filterChain);
|
filter.doFilter(request, response, filterChain);
|
||||||
|
|
||||||
assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
|
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(
|
||||||
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
|
token);
|
||||||
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
|
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
|
||||||
|
token);
|
||||||
|
|
||||||
verify(filterChain).doFilter(request, response);
|
verify(filterChain).doFilter(request, response);
|
||||||
verifyZeroInteractions(deniedHandler);
|
verifyZeroInteractions(deniedHandler);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam() throws ServletException, IOException {
|
public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam()
|
||||||
|
throws ServletException, IOException {
|
||||||
when(requestMatcher.matches(request)).thenReturn(true);
|
when(requestMatcher.matches(request)).thenReturn(true);
|
||||||
when(tokenRepository.loadToken(request)).thenReturn(token);
|
when(tokenRepository.loadToken(request)).thenReturn(token);
|
||||||
request.setParameter(token.getParameterName(), token.getToken()+ " INVALID");
|
request.setParameter(token.getParameterName(), token.getToken()
|
||||||
|
+ " INVALID");
|
||||||
request.addHeader(token.getHeaderName(), token.getToken());
|
request.addHeader(token.getHeaderName(), token.getToken());
|
||||||
|
|
||||||
filter.doFilter(request, response, filterChain);
|
filter.doFilter(request, response, filterChain);
|
||||||
|
|
||||||
assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
|
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(
|
||||||
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
|
token);
|
||||||
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
|
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
|
||||||
|
token);
|
||||||
|
|
||||||
verify(filterChain).doFilter(request, response);
|
verify(filterChain).doFilter(request, response);
|
||||||
verifyZeroInteractions(deniedHandler);
|
verifyZeroInteractions(deniedHandler);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void doFilterIsCsrfRequestExistingToken() throws ServletException, IOException {
|
public void doFilterIsCsrfRequestExistingToken() throws ServletException,
|
||||||
|
IOException {
|
||||||
when(requestMatcher.matches(request)).thenReturn(true);
|
when(requestMatcher.matches(request)).thenReturn(true);
|
||||||
when(tokenRepository.loadToken(request)).thenReturn(token);
|
when(tokenRepository.loadToken(request)).thenReturn(token);
|
||||||
request.setParameter(token.getParameterName(), token.getToken());
|
request.setParameter(token.getParameterName(), token.getToken());
|
||||||
|
|
||||||
filter.doFilter(request, response, filterChain);
|
filter.doFilter(request, response, filterChain);
|
||||||
|
|
||||||
assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
|
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(
|
||||||
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
|
token);
|
||||||
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
|
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
|
||||||
|
token);
|
||||||
|
|
||||||
verify(filterChain).doFilter(request, response);
|
verify(filterChain).doFilter(request, response);
|
||||||
verifyZeroInteractions(deniedHandler);
|
verifyZeroInteractions(deniedHandler);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void doFilterIsCsrfRequestGenerateToken() throws ServletException, IOException {
|
public void doFilterIsCsrfRequestGenerateToken() throws ServletException,
|
||||||
|
IOException {
|
||||||
when(requestMatcher.matches(request)).thenReturn(true);
|
when(requestMatcher.matches(request)).thenReturn(true);
|
||||||
when(tokenRepository.generateAndSaveToken(request, response)).thenReturn(token);
|
when(tokenRepository.generateToken(request))
|
||||||
|
.thenReturn(token);
|
||||||
request.setParameter(token.getParameterName(), token.getToken());
|
request.setParameter(token.getParameterName(), token.getToken());
|
||||||
|
|
||||||
filter.doFilter(request, response, filterChain);
|
filter.doFilter(request, response, filterChain);
|
||||||
|
|
||||||
assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
|
assertToken(request.getAttribute(token.getParameterName())).isEqualTo(
|
||||||
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
|
token);
|
||||||
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
|
assertToken(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
|
||||||
|
token);
|
||||||
|
|
||||||
verify(filterChain).doFilter(request, response);
|
verify(filterChain).doFilter(request, response);
|
||||||
verifyZeroInteractions(deniedHandler);
|
verifyZeroInteractions(deniedHandler);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods() throws ServletException, IOException {
|
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods()
|
||||||
|
throws ServletException, IOException {
|
||||||
filter = new CsrfFilter(tokenRepository);
|
filter = new CsrfFilter(tokenRepository);
|
||||||
filter.setAccessDeniedHandler(deniedHandler);
|
filter.setAccessDeniedHandler(deniedHandler);
|
||||||
|
|
||||||
for(String method : Arrays.asList("GET","TRACE", "OPTIONS", "HEAD")) {
|
for (String method : Arrays.asList("GET", "TRACE", "OPTIONS", "HEAD")) {
|
||||||
resetRequestResponse();
|
resetRequestResponse();
|
||||||
when(tokenRepository.loadToken(request)).thenReturn(token);
|
when(tokenRepository.loadToken(request)).thenReturn(token);
|
||||||
request.setMethod(method);
|
request.setMethod(method);
|
||||||
|
@ -258,24 +312,28 @@ public class CsrfFilterTests {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods() throws ServletException, IOException {
|
public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods()
|
||||||
|
throws ServletException, IOException {
|
||||||
filter = new CsrfFilter(tokenRepository);
|
filter = new CsrfFilter(tokenRepository);
|
||||||
filter.setAccessDeniedHandler(deniedHandler);
|
filter.setAccessDeniedHandler(deniedHandler);
|
||||||
|
|
||||||
for(String method : Arrays.asList("POST","PUT", "PATCH", "DELETE", "INVALID")) {
|
for (String method : Arrays.asList("POST", "PUT", "PATCH", "DELETE",
|
||||||
|
"INVALID")) {
|
||||||
resetRequestResponse();
|
resetRequestResponse();
|
||||||
when(tokenRepository.loadToken(request)).thenReturn(token);
|
when(tokenRepository.loadToken(request)).thenReturn(token);
|
||||||
request.setMethod(method);
|
request.setMethod(method);
|
||||||
|
|
||||||
filter.doFilter(request, response, filterChain);
|
filter.doFilter(request, response, filterChain);
|
||||||
|
|
||||||
verify(deniedHandler).handle(eq(request), eq(response), any(InvalidCsrfTokenException.class));
|
verify(deniedHandler).handle(eq(request), eq(response),
|
||||||
|
any(InvalidCsrfTokenException.class));
|
||||||
verifyZeroInteractions(filterChain);
|
verifyZeroInteractions(filterChain);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void doFilterDefaultAccessDenied() throws ServletException, IOException {
|
public void doFilterDefaultAccessDenied() throws ServletException,
|
||||||
|
IOException {
|
||||||
filter = new CsrfFilter(tokenRepository);
|
filter = new CsrfFilter(tokenRepository);
|
||||||
filter.setRequireCsrfProtectionMatcher(requestMatcher);
|
filter.setRequireCsrfProtectionMatcher(requestMatcher);
|
||||||
when(requestMatcher.matches(request)).thenReturn(true);
|
when(requestMatcher.matches(request)).thenReturn(true);
|
||||||
|
@ -283,11 +341,13 @@ public class CsrfFilterTests {
|
||||||
|
|
||||||
filter.doFilter(request, response, filterChain);
|
filter.doFilter(request, response, filterChain);
|
||||||
|
|
||||||
assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
|
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(
|
||||||
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
|
token);
|
||||||
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
|
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
|
||||||
|
token);
|
||||||
|
|
||||||
assertThat(response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN);
|
assertThat(response.getStatus()).isEqualTo(
|
||||||
|
HttpServletResponse.SC_FORBIDDEN);
|
||||||
verifyZeroInteractions(filterChain);
|
verifyZeroInteractions(filterChain);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -300,4 +360,29 @@ public class CsrfFilterTests {
|
||||||
public void setAccessDeniedHandlerNull() {
|
public void setAccessDeniedHandlerNull() {
|
||||||
filter.setAccessDeniedHandler(null);
|
filter.setAccessDeniedHandler(null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static final CsrfTokenAssert assertToken(Object token) {
|
||||||
|
return new CsrfTokenAssert((CsrfToken)token);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class CsrfTokenAssert extends
|
||||||
|
GenericAssert<CsrfTokenAssert, CsrfToken> {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a new </code>{@link ObjectAssert}</code>.
|
||||||
|
*
|
||||||
|
* @param actual
|
||||||
|
* the target to verify.
|
||||||
|
*/
|
||||||
|
protected CsrfTokenAssert(CsrfToken actual) {
|
||||||
|
super(CsrfTokenAssert.class, actual);
|
||||||
|
}
|
||||||
|
|
||||||
|
public CsrfTokenAssert isEqualTo(CsrfToken expected) {
|
||||||
|
assertThat(actual.getHeaderName()).isEqualTo(expected.getHeaderName());
|
||||||
|
assertThat(actual.getParameterName()).isEqualTo(expected.getParameterName());
|
||||||
|
assertThat(actual.getToken()).isEqualTo(expected.getToken());
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,38 +21,38 @@ import org.junit.Test;
|
||||||
* @author Rob Winch
|
* @author Rob Winch
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
public class CsrfTokenTests {
|
public class DefaultCsrfTokenTests {
|
||||||
private final String headerName = "headerName";
|
private final String headerName = "headerName";
|
||||||
private final String parameterName = "parameterName";
|
private final String parameterName = "parameterName";
|
||||||
private final String tokenValue = "tokenValue";
|
private final String tokenValue = "tokenValue";
|
||||||
|
|
||||||
@Test(expected = IllegalArgumentException.class)
|
@Test(expected = IllegalArgumentException.class)
|
||||||
public void constructorNullHeaderName() {
|
public void constructorNullHeaderName() {
|
||||||
new CsrfToken(null,parameterName, tokenValue);
|
new DefaultCsrfToken(null,parameterName, tokenValue);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = IllegalArgumentException.class)
|
@Test(expected = IllegalArgumentException.class)
|
||||||
public void constructorEmptyHeaderName() {
|
public void constructorEmptyHeaderName() {
|
||||||
new CsrfToken("",parameterName, tokenValue);
|
new DefaultCsrfToken("",parameterName, tokenValue);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = IllegalArgumentException.class)
|
@Test(expected = IllegalArgumentException.class)
|
||||||
public void constructorNullParameterName() {
|
public void constructorNullParameterName() {
|
||||||
new CsrfToken(headerName,null, tokenValue);
|
new DefaultCsrfToken(headerName,null, tokenValue);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = IllegalArgumentException.class)
|
@Test(expected = IllegalArgumentException.class)
|
||||||
public void constructorEmptyParameterName() {
|
public void constructorEmptyParameterName() {
|
||||||
new CsrfToken(headerName,"", tokenValue);
|
new DefaultCsrfToken(headerName,"", tokenValue);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = IllegalArgumentException.class)
|
@Test(expected = IllegalArgumentException.class)
|
||||||
public void constructorNullTokenValue() {
|
public void constructorNullTokenValue() {
|
||||||
new CsrfToken(headerName,parameterName, null);
|
new DefaultCsrfToken(headerName,parameterName, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = IllegalArgumentException.class)
|
@Test(expected = IllegalArgumentException.class)
|
||||||
public void constructorEmptyTokenValue() {
|
public void constructorEmptyTokenValue() {
|
||||||
new CsrfToken(headerName,parameterName, "");
|
new DefaultCsrfToken(headerName,parameterName, "");
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -42,23 +42,23 @@ public class HttpSessionCsrfTokenRepositoryTests {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void generateAndSaveToken() {
|
public void generateToken() {
|
||||||
token = repo.generateAndSaveToken(request, response);
|
token = repo.generateToken(request);
|
||||||
|
|
||||||
assertThat(token.getParameterName()).isEqualTo("_csrf");
|
assertThat(token.getParameterName()).isEqualTo("_csrf");
|
||||||
assertThat(token.getToken()).isNotEmpty();
|
assertThat(token.getToken()).isNotEmpty();
|
||||||
|
|
||||||
CsrfToken loadedToken = repo.loadToken(request);
|
CsrfToken loadedToken = repo.loadToken(request);
|
||||||
|
|
||||||
assertThat(loadedToken).isEqualTo(token);
|
assertThat(loadedToken).isNull();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void generateAndSaveTokenCustomParameter() {
|
public void generateCustomParameter() {
|
||||||
String paramName = "_csrf";
|
String paramName = "_csrf";
|
||||||
repo.setParameterName(paramName);
|
repo.setParameterName(paramName);
|
||||||
|
|
||||||
token = repo.generateAndSaveToken(request, response);
|
token = repo.generateToken(request);
|
||||||
|
|
||||||
assertThat(token.getParameterName()).isEqualTo(paramName);
|
assertThat(token.getParameterName()).isEqualTo(paramName);
|
||||||
assertThat(token.getToken()).isNotEmpty();
|
assertThat(token.getToken()).isNotEmpty();
|
||||||
|
@ -71,7 +71,7 @@ public class HttpSessionCsrfTokenRepositoryTests {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void saveToken() {
|
public void saveToken() {
|
||||||
CsrfToken tokenToSave = new CsrfToken("123", "abc", "def");
|
CsrfToken tokenToSave = new DefaultCsrfToken("123", "abc", "def");
|
||||||
repo.saveToken(tokenToSave, request, response);
|
repo.saveToken(tokenToSave, request, response);
|
||||||
|
|
||||||
String attrName = request.getSession().getAttributeNames()
|
String attrName = request.getSession().getAttributeNames()
|
||||||
|
@ -84,7 +84,7 @@ public class HttpSessionCsrfTokenRepositoryTests {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void saveTokenCustomSessionAttribute() {
|
public void saveTokenCustomSessionAttribute() {
|
||||||
CsrfToken tokenToSave = new CsrfToken("123", "abc", "def");
|
CsrfToken tokenToSave = new DefaultCsrfToken("123", "abc", "def");
|
||||||
String sessionAttributeName = "custom";
|
String sessionAttributeName = "custom";
|
||||||
repo.setSessionAttributeName(sessionAttributeName);
|
repo.setSessionAttributeName(sessionAttributeName);
|
||||||
repo.saveToken(tokenToSave, request, response);
|
repo.saveToken(tokenToSave, request, response);
|
||||||
|
|
|
@ -25,6 +25,7 @@ import org.junit.Test;
|
||||||
import org.springframework.mock.web.MockHttpServletRequest;
|
import org.springframework.mock.web.MockHttpServletRequest;
|
||||||
import org.springframework.mock.web.MockHttpServletResponse;
|
import org.springframework.mock.web.MockHttpServletResponse;
|
||||||
import org.springframework.security.web.csrf.CsrfToken;
|
import org.springframework.security.web.csrf.CsrfToken;
|
||||||
|
import org.springframework.security.web.csrf.DefaultCsrfToken;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author Rob Winch
|
* @author Rob Winch
|
||||||
|
@ -51,7 +52,7 @@ public class CsrfRequestDataValueProcessorTests {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void getExtraHiddenFieldsHasCsrfToken() {
|
public void getExtraHiddenFieldsHasCsrfToken() {
|
||||||
CsrfToken token = new CsrfToken("1", "a", "b");
|
CsrfToken token = new DefaultCsrfToken("1", "a", "b");
|
||||||
request.setAttribute(CsrfToken.class.getName(), token);
|
request.setAttribute(CsrfToken.class.getName(), token);
|
||||||
Map<String,String> expected = new HashMap<String,String>();
|
Map<String,String> expected = new HashMap<String,String>();
|
||||||
expected.put(token.getParameterName(),token.getToken());
|
expected.put(token.getParameterName(),token.getToken());
|
||||||
|
|
Loading…
Reference in New Issue