SEC-2276: Delay saving CsrfToken until token is accessed

This also removed the CsrfToken from the response headers to prevent the
token from being saved. If user's wish to return the CsrfToken in the
response headers, they should use the CsrfToken found on the request.
This commit is contained in:
Rob Winch 2013-08-24 23:28:15 -05:00
parent c131fb6379
commit 48283ec004
14 changed files with 355 additions and 159 deletions

View File

@ -36,6 +36,7 @@ import org.springframework.security.web.access.intercept.FilterSecurityIntercept
import org.springframework.security.web.context.HttpRequestResponseHolder
import org.springframework.security.web.context.HttpSessionSecurityContextRepository
import org.springframework.security.web.csrf.CsrfToken
import org.springframework.security.web.csrf.DefaultCsrfToken;
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository
import spock.lang.AutoCleanup
@ -69,7 +70,7 @@ abstract class BaseSpringSpec extends Specification {
}
def setupCsrf(csrfTokenValue="BaseSpringSpec_CSRFTOKEN") {
csrfToken = new CsrfToken("X-CSRF-TOKEN","_csrf",csrfTokenValue)
csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf",csrfTokenValue)
new HttpSessionCsrfTokenRepository().saveToken(csrfToken, request,response)
request.setParameter(csrfToken.parameterName, csrfToken.token)
}

View File

@ -79,8 +79,7 @@ class WebSecurityConfigurerAdapterTests extends BaseSpringSpec {
'Strict-Transport-Security': 'max-age=31536000 ; includeSubDomains',
'Cache-Control': 'no-cache,no-store,max-age=0,must-revalidate',
'Pragma':'no-cache',
'X-XSS-Protection' : '1; mode=block',
'X-CSRF-TOKEN' : csrfToken.token]
'X-XSS-Protection' : '1; mode=block']
}
@EnableWebSecurity

View File

@ -49,8 +49,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec {
'Strict-Transport-Security': 'max-age=31536000 ; includeSubDomains',
'Cache-Control': 'no-cache,no-store,max-age=0,must-revalidate',
'Pragma':'no-cache',
'X-XSS-Protection' : '1; mode=block',
'X-CSRF-TOKEN' : csrfToken.token]
'X-XSS-Protection' : '1; mode=block']
}
@Configuration
@ -70,8 +69,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec {
springSecurityFilterChain.doFilter(request,response,chain)
then:
responseHeaders == ['Cache-Control': 'no-cache,no-store,max-age=0,must-revalidate',
'Pragma':'no-cache',
'X-CSRF-TOKEN' : csrfToken.token]
'Pragma':'no-cache']
}
@Configuration
@ -91,8 +89,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec {
when:
springSecurityFilterChain.doFilter(request,response,chain)
then:
responseHeaders == ['Strict-Transport-Security': 'max-age=31536000 ; includeSubDomains',
'X-CSRF-TOKEN' : csrfToken.token]
responseHeaders == ['Strict-Transport-Security': 'max-age=31536000 ; includeSubDomains']
}
@Configuration
@ -111,8 +108,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec {
when:
springSecurityFilterChain.doFilter(request,response,chain)
then:
responseHeaders == ['Strict-Transport-Security': 'max-age=15768000',
'X-CSRF-TOKEN' : csrfToken.token]
responseHeaders == ['Strict-Transport-Security': 'max-age=15768000']
}
@Configuration
@ -133,8 +129,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec {
when:
springSecurityFilterChain.doFilter(request,response,chain)
then:
responseHeaders == ['X-Frame-Options': 'SAMEORIGIN',
'X-CSRF-TOKEN' : csrfToken.token]
responseHeaders == ['X-Frame-Options': 'SAMEORIGIN']
}
@Configuration
@ -156,8 +151,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec {
when:
springSecurityFilterChain.doFilter(request,response,chain)
then:
responseHeaders == ['X-Frame-Options': 'ALLOW-FROM https://example.com',
'X-CSRF-TOKEN' : csrfToken.token]
responseHeaders == ['X-Frame-Options': 'ALLOW-FROM https://example.com']
}
@ -178,8 +172,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec {
when:
springSecurityFilterChain.doFilter(request,response,chain)
then:
responseHeaders == ['X-XSS-Protection': '1; mode=block',
'X-CSRF-TOKEN' : csrfToken.token]
responseHeaders == ['X-XSS-Protection': '1; mode=block']
}
@Configuration
@ -199,8 +192,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec {
when:
springSecurityFilterChain.doFilter(request,response,chain)
then:
responseHeaders == ['X-XSS-Protection': '1',
'X-CSRF-TOKEN' : csrfToken.token]
responseHeaders == ['X-XSS-Protection': '1']
}
@Configuration
@ -220,8 +212,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec {
when:
springSecurityFilterChain.doFilter(request,response,chain)
then:
responseHeaders == ['X-Content-Type-Options': 'nosniff',
'X-CSRF-TOKEN' : csrfToken.token]
responseHeaders == ['X-Content-Type-Options': 'nosniff']
}
@Configuration
@ -243,8 +234,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec {
when:
springSecurityFilterChain.doFilter(request,response,chain)
then:
responseHeaders == ['customHeaderName': 'customHeaderValue',
'X-CSRF-TOKEN' : csrfToken.token]
responseHeaders == ['customHeaderName': 'customHeaderValue']
}
@Configuration

View File

@ -29,6 +29,7 @@ import org.springframework.security.web.access.AccessDeniedHandler;
import org.springframework.security.web.csrf.CsrfFilter
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.CsrfTokenRepository;
import org.springframework.security.web.csrf.DefaultCsrfToken;
import org.springframework.security.web.servlet.support.csrf.CsrfRequestDataValueProcessor
import org.springframework.security.web.util.RequestMatcher
@ -113,7 +114,7 @@ class CsrfConfigTests extends AbstractHttpConfigTests {
mockBean(CsrfTokenRepository,'repo')
createAppContext()
CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository)
CsrfToken token = new CsrfToken("X-CSRF-TOKEN","_csrf", "abc")
CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc")
when(repo.loadToken(any(HttpServletRequest))).thenReturn(token)
request.setParameter(token.parameterName,token.token)
request.servletPath = "/some-url"
@ -147,7 +148,7 @@ class CsrfConfigTests extends AbstractHttpConfigTests {
mockBean(CsrfTokenRepository,'repo')
createAppContext()
CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository)
CsrfToken token = new CsrfToken("X-CSRF-TOKEN","_csrf", "abc")
CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc")
when(repo.loadToken(any(HttpServletRequest))).thenReturn(token)
request.setParameter(token.parameterName,token.token)
request.servletPath = "/some-url"
@ -200,7 +201,7 @@ class CsrfConfigTests extends AbstractHttpConfigTests {
mockBean(CsrfTokenRepository,'repo')
createAppContext()
CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository)
CsrfToken token = new CsrfToken("X-CSRF-TOKEN","_csrf", "abc")
CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc")
when(repo.loadToken(any(HttpServletRequest))).thenReturn(token)
request.setParameter(token.parameterName,token.token)
request.method = "POST"
@ -223,7 +224,7 @@ class CsrfConfigTests extends AbstractHttpConfigTests {
mockBean(CsrfTokenRepository,'repo')
createAppContext()
CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository)
CsrfToken token = new CsrfToken("X-CSRF-TOKEN","_csrf", "abc")
CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc")
when(repo.loadToken(any(HttpServletRequest))).thenReturn(token)
request.setParameter(token.parameterName,token.token)
request.method = "POST"
@ -244,7 +245,7 @@ class CsrfConfigTests extends AbstractHttpConfigTests {
mockBean(CsrfTokenRepository,'repo')
createAppContext()
CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository)
CsrfToken token = new CsrfToken("X-CSRF-TOKEN","_csrf", "abc")
CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc")
when(repo.loadToken(any(HttpServletRequest))).thenReturn(token)
request.setParameter(token.parameterName,token.token)
request.method = "POST"

View File

@ -40,7 +40,6 @@ import org.springframework.context.annotation.Configuration;
import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
@ -96,7 +95,9 @@ public class SessionManagementConfigurerServlet31Tests {
request.setMethod("POST");
request.setParameter("username", "user");
request.setParameter("password", "password");
CsrfToken token = new HttpSessionCsrfTokenRepository().generateAndSaveToken(request, response);
HttpSessionCsrfTokenRepository repository = new HttpSessionCsrfTokenRepository();
CsrfToken token = repository.generateToken(request);
repository.saveToken(token, request, response);
request.setParameter(token.getParameterName(),token.getToken());
when(ReflectionUtils.findMethod(HttpServletRequest.class, "changeSessionId")).thenReturn(method);

View File

@ -70,11 +70,11 @@ public final class CsrfFilter extends OncePerRequestFilter {
throws ServletException, IOException {
CsrfToken csrfToken = tokenRepository.loadToken(request);
if(csrfToken == null) {
csrfToken = tokenRepository.generateAndSaveToken(request, response);
CsrfToken generatedToken = tokenRepository.generateToken(request);
csrfToken = new SaveOnAccessCsrfToken(tokenRepository, request, response, generatedToken);
}
request.setAttribute(CsrfToken.class.getName(), csrfToken);
request.setAttribute(csrfToken.getParameterName(), csrfToken);
response.addHeader(csrfToken.getHeaderName(), csrfToken.getToken());
if(!requireCsrfProtectionMatcher.matches(request)) {
filterChain.doFilter(request, response);
@ -128,7 +128,86 @@ public final class CsrfFilter extends OncePerRequestFilter {
this.accessDeniedHandler = accessDeniedHandler;
}
private static class DefaultRequiresCsrfMatcher implements RequestMatcher {
@SuppressWarnings("serial")
private static final class SaveOnAccessCsrfToken implements CsrfToken {
private transient CsrfTokenRepository tokenRepository;
private transient HttpServletRequest request;
private transient HttpServletResponse response;
private final CsrfToken delegate;
public SaveOnAccessCsrfToken(CsrfTokenRepository tokenRepository,
HttpServletRequest request, HttpServletResponse response,
CsrfToken delegate) {
super();
this.tokenRepository = tokenRepository;
this.request = request;
this.response = response;
this.delegate = delegate;
}
public String getHeaderName() {
return delegate.getHeaderName();
}
public String getParameterName() {
return delegate.getParameterName();
}
public String getToken() {
saveTokenIfNecessary();
return delegate.getToken();
}
@Override
public String toString() {
return "SaveOnAccessCsrfToken [delegate=" + delegate + "]";
}
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result
+ ((delegate == null) ? 0 : delegate.hashCode());
return result;
}
@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
if (obj == null)
return false;
if (getClass() != obj.getClass())
return false;
SaveOnAccessCsrfToken other = (SaveOnAccessCsrfToken) obj;
if (delegate == null) {
if (other.delegate != null)
return false;
} else if (!delegate.equals(other.delegate))
return false;
return true;
}
private void saveTokenIfNecessary() {
if(this.tokenRepository == null) {
return;
}
synchronized(this) {
if(tokenRepository != null) {
this.tokenRepository.saveToken(delegate, request, response);
this.tokenRepository = null;
this.request = null;
this.response = null;
}
}
}
}
private static final class DefaultRequiresCsrfMatcher implements RequestMatcher {
private Pattern allowedMethods = Pattern.compile("^(GET|HEAD|TRACE|OPTIONS)$");
/* (non-Javadoc)

View File

@ -17,37 +17,16 @@ package org.springframework.security.web.csrf;
import java.io.Serializable;
import org.springframework.util.Assert;
/**
* A CSRF token that is used to protect against CSRF attacks.
* Provides the information about an expected CSRF token.
*
* @see DefaultCsrfToken
*
* @author Rob Winch
* @since 3.2
*
*/
@SuppressWarnings("serial")
public final class CsrfToken implements Serializable {
private final String token;
private final String parameterName;
private final String headerName;
/**
* Creates a new instance
* @param headerName the HTTP header name to use
* @param parameterName the HTTP parameter name to use
* @param token the value of the token (i.e. expected value of the HTTP parameter of parametername).
*/
public CsrfToken(String headerName, String parameterName, String token) {
Assert.hasLength(headerName, "headerName cannot be null or empty");
Assert.hasLength(parameterName, "parameterName cannot be null or empty");
Assert.hasLength(token, "token cannot be null or empty");
this.headerName = headerName;
this.parameterName = parameterName;
this.token = token;
}
public interface CsrfToken extends Serializable {
/**
* Gets the HTTP header that the CSRF is populated on the response and can
@ -56,23 +35,18 @@ public final class CsrfToken implements Serializable {
* @return the HTTP header that the CSRF is populated on the response and
* can be placed on requests instead of the parameter
*/
public String getHeaderName() {
return headerName;
}
String getHeaderName();
/**
* Gets the HTTP parameter name that should contain the token. Cannot be null.
* @return the HTTP parameter name that should contain the token.
*/
public String getParameterName() {
return parameterName;
}
String getParameterName();
/**
* Gets the token value. Cannot be null.
* @return the token value
*/
public String getToken() {
return token;
}
}
String getToken();
}

View File

@ -33,17 +33,14 @@ import javax.servlet.http.HttpSession;
public interface CsrfTokenRepository {
/**
* Generates and saves the expected {@link CsrfToken}
* Generates a {@link CsrfToken}
*
* @param request
* the {@link HttpServletRequest} to use
* @param response
* the {@link HttpServletResponse} to use
* @return the {@link CsrfToken} that was generated and saved. Cannot be
* @return the {@link CsrfToken} that was generated. Cannot be
* null.
*/
CsrfToken generateAndSaveToken(HttpServletRequest request,
HttpServletResponse response);
CsrfToken generateToken(HttpServletRequest request);
/**
* Saves the {@link CsrfToken} using the {@link HttpServletRequest} and

View File

@ -0,0 +1,70 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.web.csrf;
import org.springframework.util.Assert;
/**
* A CSRF token that is used to protect against CSRF attacks.
*
* @author Rob Winch
* @since 3.2
*/
@SuppressWarnings("serial")
public final class DefaultCsrfToken implements CsrfToken {
private final String token;
private final String parameterName;
private final String headerName;
/**
* Creates a new instance
* @param headerName the HTTP header name to use
* @param parameterName the HTTP parameter name to use
* @param token the value of the token (i.e. expected value of the HTTP parameter of parametername).
*/
public DefaultCsrfToken(String headerName, String parameterName, String token) {
Assert.hasLength(headerName, "headerName cannot be null or empty");
Assert.hasLength(parameterName, "parameterName cannot be null or empty");
Assert.hasLength(token, "token cannot be null or empty");
this.headerName = headerName;
this.parameterName = parameterName;
this.token = token;
}
/* (non-Javadoc)
* @see org.springframework.security.web.csrf.CsrfToken#getHeaderName()
*/
public String getHeaderName() {
return headerName;
}
/* (non-Javadoc)
* @see org.springframework.security.web.csrf.CsrfToken#getParameterName()
*/
public String getParameterName() {
return parameterName;
}
/* (non-Javadoc)
* @see org.springframework.security.web.csrf.CsrfToken#getToken()
*/
public String getToken() {
return token;
}
}

View File

@ -63,14 +63,12 @@ public final class HttpSessionCsrfTokenRepository implements CsrfTokenRepository
return (CsrfToken) request.getSession().getAttribute(sessionAttributeName);
}
/* (non-Javadoc)
* @see org.springframework.security.web.csrf.CsrfTokenRepository#generateNewToken(javax.servlet.http.HttpServletRequest, javax.servlet.http.HttpServletResponse)
/*
* (non-Javadoc)
* @see org.springframework.security.web.csrf.CsrfTokenRepository#generateToken(javax.servlet.http.HttpServletRequest)
*/
public CsrfToken generateAndSaveToken(HttpServletRequest request,
HttpServletResponse response) {
CsrfToken token = new CsrfToken(headerName, parameterName, createNewToken());
saveToken(token, request, response);
return token;
public CsrfToken generateToken(HttpServletRequest request) {
return new DefaultCsrfToken(headerName, parameterName, createNewToken());
}
/**

View File

@ -18,6 +18,7 @@ package org.springframework.security.web.csrf;
import static org.fest.assertions.Assertions.assertThat;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
@ -27,8 +28,11 @@ import java.util.Arrays;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.fest.assertions.GenericAssert;
import org.fest.assertions.ObjectAssert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
@ -59,12 +63,12 @@ public class CsrfFilterTests {
private MockHttpServletResponse response;
private CsrfToken token;
private CsrfFilter filter;
@Before
public void setup() {
token = new CsrfToken("headerName","paramName", "csrfTokenValue");
token = new DefaultCsrfToken("headerName", "paramName",
"csrfTokenValue");
resetRequestResponse();
filter = new CsrfFilter(tokenRepository);
filter.setRequireCsrfProtectionMatcher(requestMatcher);
@ -81,171 +85,221 @@ public class CsrfFilterTests {
new CsrfFilter(null);
}
// SEC-2276
@Test
public void doFilterAccessDeniedNoTokenPresent() throws ServletException, IOException {
public void doFilterDoesNotSaveCsrfTokenUntilAccessed() throws ServletException,
IOException {
when(requestMatcher.matches(request)).thenReturn(false);
when(tokenRepository.generateToken(request)).thenReturn(token);
filter.doFilter(request, response, filterChain);
CsrfToken attrToken = (CsrfToken) request.getAttribute(token.getParameterName());
// no CsrfToken should have been saved yet
verify(tokenRepository,times(0)).saveToken(any(CsrfToken.class), any(HttpServletRequest.class), any(HttpServletResponse.class));
verify(filterChain).doFilter(request, response);
// access the token
attrToken.getToken();
// now the CsrfToken should have been saved
verify(tokenRepository).saveToken(eq(token), any(HttpServletRequest.class), any(HttpServletResponse.class));
}
@Test
public void doFilterAccessDeniedNoTokenPresent() throws ServletException,
IOException {
when(requestMatcher.matches(request)).thenReturn(true);
when(tokenRepository.loadToken(request)).thenReturn(token);
filter.doFilter(request, response, filterChain);
assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(
token);
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
token);
verify(deniedHandler).handle(eq(request), eq(response), any(InvalidCsrfTokenException.class));
verify(deniedHandler).handle(eq(request), eq(response),
any(InvalidCsrfTokenException.class));
verifyZeroInteractions(filterChain);
}
@Test
public void doFilterAccessDeniedIncorrectTokenPresent() throws ServletException, IOException {
public void doFilterAccessDeniedIncorrectTokenPresent()
throws ServletException, IOException {
when(requestMatcher.matches(request)).thenReturn(true);
when(tokenRepository.loadToken(request)).thenReturn(token);
request.setParameter(token.getParameterName(), token.getToken()+ " INVALID");
request.setParameter(token.getParameterName(), token.getToken()
+ " INVALID");
filter.doFilter(request, response, filterChain);
assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(
token);
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
token);
verify(deniedHandler).handle(eq(request), eq(response), any(InvalidCsrfTokenException.class));
verify(deniedHandler).handle(eq(request), eq(response),
any(InvalidCsrfTokenException.class));
verifyZeroInteractions(filterChain);
}
@Test
public void doFilterAccessDeniedIncorrectTokenPresentHeader() throws ServletException, IOException {
public void doFilterAccessDeniedIncorrectTokenPresentHeader()
throws ServletException, IOException {
when(requestMatcher.matches(request)).thenReturn(true);
when(tokenRepository.loadToken(request)).thenReturn(token);
request.addHeader(token.getHeaderName(), token.getToken()+ " INVALID");
request.addHeader(token.getHeaderName(), token.getToken() + " INVALID");
filter.doFilter(request, response, filterChain);
assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(
token);
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
token);
verify(deniedHandler).handle(eq(request), eq(response), any(InvalidCsrfTokenException.class));
verify(deniedHandler).handle(eq(request), eq(response),
any(InvalidCsrfTokenException.class));
verifyZeroInteractions(filterChain);
}
@Test
public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter() throws ServletException, IOException {
public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter()
throws ServletException, IOException {
when(requestMatcher.matches(request)).thenReturn(true);
when(tokenRepository.loadToken(request)).thenReturn(token);
request.setParameter(token.getParameterName(), token.getToken());
request.addHeader(token.getHeaderName(), token.getToken()+ " INVALID");
request.addHeader(token.getHeaderName(), token.getToken() + " INVALID");
filter.doFilter(request, response, filterChain);
assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(
token);
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
token);
verify(deniedHandler).handle(eq(request), eq(response), any(InvalidCsrfTokenException.class));
verify(deniedHandler).handle(eq(request), eq(response),
any(InvalidCsrfTokenException.class));
verifyZeroInteractions(filterChain);
}
@Test
public void doFilterNotCsrfRequestExistingToken() throws ServletException, IOException {
public void doFilterNotCsrfRequestExistingToken() throws ServletException,
IOException {
when(requestMatcher.matches(request)).thenReturn(false);
when(tokenRepository.loadToken(request)).thenReturn(token);
filter.doFilter(request, response, filterChain);
assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(
token);
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
token);
verify(filterChain).doFilter(request, response);
verifyZeroInteractions(deniedHandler);
}
@Test
public void doFilterNotCsrfRequestGenerateToken() throws ServletException, IOException {
public void doFilterNotCsrfRequestGenerateToken() throws ServletException,
IOException {
when(requestMatcher.matches(request)).thenReturn(false);
when(tokenRepository.generateAndSaveToken(request, response)).thenReturn(token);
when(tokenRepository.generateToken(request))
.thenReturn(token);
filter.doFilter(request, response, filterChain);
assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
assertToken(request.getAttribute(token.getParameterName())).isEqualTo(
token);
assertToken(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
token);
verify(filterChain).doFilter(request, response);
verifyZeroInteractions(deniedHandler);
}
@Test
public void doFilterIsCsrfRequestExistingTokenHeader() throws ServletException, IOException {
public void doFilterIsCsrfRequestExistingTokenHeader()
throws ServletException, IOException {
when(requestMatcher.matches(request)).thenReturn(true);
when(tokenRepository.loadToken(request)).thenReturn(token);
request.addHeader(token.getHeaderName(), token.getToken());
filter.doFilter(request, response, filterChain);
assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(
token);
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
token);
verify(filterChain).doFilter(request, response);
verifyZeroInteractions(deniedHandler);
}
@Test
public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam() throws ServletException, IOException {
public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam()
throws ServletException, IOException {
when(requestMatcher.matches(request)).thenReturn(true);
when(tokenRepository.loadToken(request)).thenReturn(token);
request.setParameter(token.getParameterName(), token.getToken()+ " INVALID");
request.setParameter(token.getParameterName(), token.getToken()
+ " INVALID");
request.addHeader(token.getHeaderName(), token.getToken());
filter.doFilter(request, response, filterChain);
assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(
token);
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
token);
verify(filterChain).doFilter(request, response);
verifyZeroInteractions(deniedHandler);
}
@Test
public void doFilterIsCsrfRequestExistingToken() throws ServletException, IOException {
public void doFilterIsCsrfRequestExistingToken() throws ServletException,
IOException {
when(requestMatcher.matches(request)).thenReturn(true);
when(tokenRepository.loadToken(request)).thenReturn(token);
request.setParameter(token.getParameterName(), token.getToken());
filter.doFilter(request, response, filterChain);
assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(
token);
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
token);
verify(filterChain).doFilter(request, response);
verifyZeroInteractions(deniedHandler);
}
@Test
public void doFilterIsCsrfRequestGenerateToken() throws ServletException, IOException {
public void doFilterIsCsrfRequestGenerateToken() throws ServletException,
IOException {
when(requestMatcher.matches(request)).thenReturn(true);
when(tokenRepository.generateAndSaveToken(request, response)).thenReturn(token);
when(tokenRepository.generateToken(request))
.thenReturn(token);
request.setParameter(token.getParameterName(), token.getToken());
filter.doFilter(request, response, filterChain);
assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
assertToken(request.getAttribute(token.getParameterName())).isEqualTo(
token);
assertToken(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
token);
verify(filterChain).doFilter(request, response);
verifyZeroInteractions(deniedHandler);
}
@Test
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods() throws ServletException, IOException {
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods()
throws ServletException, IOException {
filter = new CsrfFilter(tokenRepository);
filter.setAccessDeniedHandler(deniedHandler);
for(String method : Arrays.asList("GET","TRACE", "OPTIONS", "HEAD")) {
for (String method : Arrays.asList("GET", "TRACE", "OPTIONS", "HEAD")) {
resetRequestResponse();
when(tokenRepository.loadToken(request)).thenReturn(token);
request.setMethod(method);
@ -258,24 +312,28 @@ public class CsrfFilterTests {
}
@Test
public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods() throws ServletException, IOException {
public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods()
throws ServletException, IOException {
filter = new CsrfFilter(tokenRepository);
filter.setAccessDeniedHandler(deniedHandler);
for(String method : Arrays.asList("POST","PUT", "PATCH", "DELETE", "INVALID")) {
for (String method : Arrays.asList("POST", "PUT", "PATCH", "DELETE",
"INVALID")) {
resetRequestResponse();
when(tokenRepository.loadToken(request)).thenReturn(token);
request.setMethod(method);
filter.doFilter(request, response, filterChain);
verify(deniedHandler).handle(eq(request), eq(response), any(InvalidCsrfTokenException.class));
verify(deniedHandler).handle(eq(request), eq(response),
any(InvalidCsrfTokenException.class));
verifyZeroInteractions(filterChain);
}
}
@Test
public void doFilterDefaultAccessDenied() throws ServletException, IOException {
public void doFilterDefaultAccessDenied() throws ServletException,
IOException {
filter = new CsrfFilter(tokenRepository);
filter.setRequireCsrfProtectionMatcher(requestMatcher);
when(requestMatcher.matches(request)).thenReturn(true);
@ -283,11 +341,13 @@ public class CsrfFilterTests {
filter.doFilter(request, response, filterChain);
assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
assertThat(request.getAttribute(token.getParameterName())).isEqualTo(
token);
assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
token);
assertThat(response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN);
assertThat(response.getStatus()).isEqualTo(
HttpServletResponse.SC_FORBIDDEN);
verifyZeroInteractions(filterChain);
}
@ -300,4 +360,29 @@ public class CsrfFilterTests {
public void setAccessDeniedHandlerNull() {
filter.setAccessDeniedHandler(null);
}
private static final CsrfTokenAssert assertToken(Object token) {
return new CsrfTokenAssert((CsrfToken)token);
}
private static class CsrfTokenAssert extends
GenericAssert<CsrfTokenAssert, CsrfToken> {
/**
* Creates a new </code>{@link ObjectAssert}</code>.
*
* @param actual
* the target to verify.
*/
protected CsrfTokenAssert(CsrfToken actual) {
super(CsrfTokenAssert.class, actual);
}
public CsrfTokenAssert isEqualTo(CsrfToken expected) {
assertThat(actual.getHeaderName()).isEqualTo(expected.getHeaderName());
assertThat(actual.getParameterName()).isEqualTo(expected.getParameterName());
assertThat(actual.getToken()).isEqualTo(expected.getToken());
return this;
}
}
}

View File

@ -21,38 +21,38 @@ import org.junit.Test;
* @author Rob Winch
*
*/
public class CsrfTokenTests {
public class DefaultCsrfTokenTests {
private final String headerName = "headerName";
private final String parameterName = "parameterName";
private final String tokenValue = "tokenValue";
@Test(expected = IllegalArgumentException.class)
public void constructorNullHeaderName() {
new CsrfToken(null,parameterName, tokenValue);
new DefaultCsrfToken(null,parameterName, tokenValue);
}
@Test(expected = IllegalArgumentException.class)
public void constructorEmptyHeaderName() {
new CsrfToken("",parameterName, tokenValue);
new DefaultCsrfToken("",parameterName, tokenValue);
}
@Test(expected = IllegalArgumentException.class)
public void constructorNullParameterName() {
new CsrfToken(headerName,null, tokenValue);
new DefaultCsrfToken(headerName,null, tokenValue);
}
@Test(expected = IllegalArgumentException.class)
public void constructorEmptyParameterName() {
new CsrfToken(headerName,"", tokenValue);
new DefaultCsrfToken(headerName,"", tokenValue);
}
@Test(expected = IllegalArgumentException.class)
public void constructorNullTokenValue() {
new CsrfToken(headerName,parameterName, null);
new DefaultCsrfToken(headerName,parameterName, null);
}
@Test(expected = IllegalArgumentException.class)
public void constructorEmptyTokenValue() {
new CsrfToken(headerName,parameterName, "");
new DefaultCsrfToken(headerName,parameterName, "");
}
}

View File

@ -42,23 +42,23 @@ public class HttpSessionCsrfTokenRepositoryTests {
}
@Test
public void generateAndSaveToken() {
token = repo.generateAndSaveToken(request, response);
public void generateToken() {
token = repo.generateToken(request);
assertThat(token.getParameterName()).isEqualTo("_csrf");
assertThat(token.getToken()).isNotEmpty();
CsrfToken loadedToken = repo.loadToken(request);
assertThat(loadedToken).isEqualTo(token);
assertThat(loadedToken).isNull();
}
@Test
public void generateAndSaveTokenCustomParameter() {
public void generateCustomParameter() {
String paramName = "_csrf";
repo.setParameterName(paramName);
token = repo.generateAndSaveToken(request, response);
token = repo.generateToken(request);
assertThat(token.getParameterName()).isEqualTo(paramName);
assertThat(token.getToken()).isNotEmpty();
@ -71,7 +71,7 @@ public class HttpSessionCsrfTokenRepositoryTests {
@Test
public void saveToken() {
CsrfToken tokenToSave = new CsrfToken("123", "abc", "def");
CsrfToken tokenToSave = new DefaultCsrfToken("123", "abc", "def");
repo.saveToken(tokenToSave, request, response);
String attrName = request.getSession().getAttributeNames()
@ -84,7 +84,7 @@ public class HttpSessionCsrfTokenRepositoryTests {
@Test
public void saveTokenCustomSessionAttribute() {
CsrfToken tokenToSave = new CsrfToken("123", "abc", "def");
CsrfToken tokenToSave = new DefaultCsrfToken("123", "abc", "def");
String sessionAttributeName = "custom";
repo.setSessionAttributeName(sessionAttributeName);
repo.saveToken(tokenToSave, request, response);

View File

@ -25,6 +25,7 @@ import org.junit.Test;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.DefaultCsrfToken;
/**
* @author Rob Winch
@ -51,7 +52,7 @@ public class CsrfRequestDataValueProcessorTests {
@Test
public void getExtraHiddenFieldsHasCsrfToken() {
CsrfToken token = new CsrfToken("1", "a", "b");
CsrfToken token = new DefaultCsrfToken("1", "a", "b");
request.setAttribute(CsrfToken.class.getName(), token);
Map<String,String> expected = new HashMap<String,String>();
expected.put(token.getParameterName(),token.getToken());