From 48283ec0049b4245a8887ffaf89d60416fb12203 Mon Sep 17 00:00:00 2001 From: Rob Winch Date: Sat, 24 Aug 2013 23:28:15 -0500 Subject: [PATCH] SEC-2276: Delay saving CsrfToken until token is accessed This also removed the CsrfToken from the response headers to prevent the token from being saved. If user's wish to return the CsrfToken in the response headers, they should use the CsrfToken found on the request. --- .../config/annotation/BaseSpringSpec.groovy | 3 +- .../WebSecurityConfigurerAdapterTests.groovy | 3 +- .../NamespaceHttpHeadersTests.groovy | 30 +-- .../config/http/CsrfConfigTests.groovy | 11 +- ...ionManagementConfigurerServlet31Tests.java | 5 +- .../security/web/csrf/CsrfFilter.java | 85 ++++++- .../security/web/csrf/CsrfToken.java | 46 +--- .../web/csrf/CsrfTokenRepository.java | 9 +- .../security/web/csrf/DefaultCsrfToken.java | 70 ++++++ .../csrf/HttpSessionCsrfTokenRepository.java | 12 +- .../security/web/csrf/CsrfFilterTests.java | 209 ++++++++++++------ ...nTests.java => DefaultCsrfTokenTests.java} | 14 +- .../HttpSessionCsrfTokenRepositoryTests.java | 14 +- .../CsrfRequestDataValueProcessorTests.java | 3 +- 14 files changed, 355 insertions(+), 159 deletions(-) create mode 100644 web/src/main/java/org/springframework/security/web/csrf/DefaultCsrfToken.java rename web/src/test/java/org/springframework/security/web/csrf/{CsrfTokenTests.java => DefaultCsrfTokenTests.java} (79%) diff --git a/config/src/test/groovy/org/springframework/security/config/annotation/BaseSpringSpec.groovy b/config/src/test/groovy/org/springframework/security/config/annotation/BaseSpringSpec.groovy index a29320da2b..489501e0e9 100644 --- a/config/src/test/groovy/org/springframework/security/config/annotation/BaseSpringSpec.groovy +++ b/config/src/test/groovy/org/springframework/security/config/annotation/BaseSpringSpec.groovy @@ -36,6 +36,7 @@ import org.springframework.security.web.access.intercept.FilterSecurityIntercept import org.springframework.security.web.context.HttpRequestResponseHolder import org.springframework.security.web.context.HttpSessionSecurityContextRepository import org.springframework.security.web.csrf.CsrfToken +import org.springframework.security.web.csrf.DefaultCsrfToken; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository import spock.lang.AutoCleanup @@ -69,7 +70,7 @@ abstract class BaseSpringSpec extends Specification { } def setupCsrf(csrfTokenValue="BaseSpringSpec_CSRFTOKEN") { - csrfToken = new CsrfToken("X-CSRF-TOKEN","_csrf",csrfTokenValue) + csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf",csrfTokenValue) new HttpSessionCsrfTokenRepository().saveToken(csrfToken, request,response) request.setParameter(csrfToken.parameterName, csrfToken.token) } diff --git a/config/src/test/groovy/org/springframework/security/config/annotation/web/WebSecurityConfigurerAdapterTests.groovy b/config/src/test/groovy/org/springframework/security/config/annotation/web/WebSecurityConfigurerAdapterTests.groovy index 32a6e40f53..6e15ea0cc9 100644 --- a/config/src/test/groovy/org/springframework/security/config/annotation/web/WebSecurityConfigurerAdapterTests.groovy +++ b/config/src/test/groovy/org/springframework/security/config/annotation/web/WebSecurityConfigurerAdapterTests.groovy @@ -79,8 +79,7 @@ class WebSecurityConfigurerAdapterTests extends BaseSpringSpec { 'Strict-Transport-Security': 'max-age=31536000 ; includeSubDomains', 'Cache-Control': 'no-cache,no-store,max-age=0,must-revalidate', 'Pragma':'no-cache', - 'X-XSS-Protection' : '1; mode=block', - 'X-CSRF-TOKEN' : csrfToken.token] + 'X-XSS-Protection' : '1; mode=block'] } @EnableWebSecurity diff --git a/config/src/test/groovy/org/springframework/security/config/annotation/web/configurers/NamespaceHttpHeadersTests.groovy b/config/src/test/groovy/org/springframework/security/config/annotation/web/configurers/NamespaceHttpHeadersTests.groovy index e5d4a37258..66595012b8 100644 --- a/config/src/test/groovy/org/springframework/security/config/annotation/web/configurers/NamespaceHttpHeadersTests.groovy +++ b/config/src/test/groovy/org/springframework/security/config/annotation/web/configurers/NamespaceHttpHeadersTests.groovy @@ -49,8 +49,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec { 'Strict-Transport-Security': 'max-age=31536000 ; includeSubDomains', 'Cache-Control': 'no-cache,no-store,max-age=0,must-revalidate', 'Pragma':'no-cache', - 'X-XSS-Protection' : '1; mode=block', - 'X-CSRF-TOKEN' : csrfToken.token] + 'X-XSS-Protection' : '1; mode=block'] } @Configuration @@ -70,8 +69,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec { springSecurityFilterChain.doFilter(request,response,chain) then: responseHeaders == ['Cache-Control': 'no-cache,no-store,max-age=0,must-revalidate', - 'Pragma':'no-cache', - 'X-CSRF-TOKEN' : csrfToken.token] + 'Pragma':'no-cache'] } @Configuration @@ -91,8 +89,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec { when: springSecurityFilterChain.doFilter(request,response,chain) then: - responseHeaders == ['Strict-Transport-Security': 'max-age=31536000 ; includeSubDomains', - 'X-CSRF-TOKEN' : csrfToken.token] + responseHeaders == ['Strict-Transport-Security': 'max-age=31536000 ; includeSubDomains'] } @Configuration @@ -111,8 +108,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec { when: springSecurityFilterChain.doFilter(request,response,chain) then: - responseHeaders == ['Strict-Transport-Security': 'max-age=15768000', - 'X-CSRF-TOKEN' : csrfToken.token] + responseHeaders == ['Strict-Transport-Security': 'max-age=15768000'] } @Configuration @@ -133,8 +129,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec { when: springSecurityFilterChain.doFilter(request,response,chain) then: - responseHeaders == ['X-Frame-Options': 'SAMEORIGIN', - 'X-CSRF-TOKEN' : csrfToken.token] + responseHeaders == ['X-Frame-Options': 'SAMEORIGIN'] } @Configuration @@ -156,8 +151,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec { when: springSecurityFilterChain.doFilter(request,response,chain) then: - responseHeaders == ['X-Frame-Options': 'ALLOW-FROM https://example.com', - 'X-CSRF-TOKEN' : csrfToken.token] + responseHeaders == ['X-Frame-Options': 'ALLOW-FROM https://example.com'] } @@ -178,8 +172,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec { when: springSecurityFilterChain.doFilter(request,response,chain) then: - responseHeaders == ['X-XSS-Protection': '1; mode=block', - 'X-CSRF-TOKEN' : csrfToken.token] + responseHeaders == ['X-XSS-Protection': '1; mode=block'] } @Configuration @@ -199,8 +192,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec { when: springSecurityFilterChain.doFilter(request,response,chain) then: - responseHeaders == ['X-XSS-Protection': '1', - 'X-CSRF-TOKEN' : csrfToken.token] + responseHeaders == ['X-XSS-Protection': '1'] } @Configuration @@ -220,8 +212,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec { when: springSecurityFilterChain.doFilter(request,response,chain) then: - responseHeaders == ['X-Content-Type-Options': 'nosniff', - 'X-CSRF-TOKEN' : csrfToken.token] + responseHeaders == ['X-Content-Type-Options': 'nosniff'] } @Configuration @@ -243,8 +234,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec { when: springSecurityFilterChain.doFilter(request,response,chain) then: - responseHeaders == ['customHeaderName': 'customHeaderValue', - 'X-CSRF-TOKEN' : csrfToken.token] + responseHeaders == ['customHeaderName': 'customHeaderValue'] } @Configuration diff --git a/config/src/test/groovy/org/springframework/security/config/http/CsrfConfigTests.groovy b/config/src/test/groovy/org/springframework/security/config/http/CsrfConfigTests.groovy index 49d5c6c80c..72693e700a 100644 --- a/config/src/test/groovy/org/springframework/security/config/http/CsrfConfigTests.groovy +++ b/config/src/test/groovy/org/springframework/security/config/http/CsrfConfigTests.groovy @@ -29,6 +29,7 @@ import org.springframework.security.web.access.AccessDeniedHandler; import org.springframework.security.web.csrf.CsrfFilter import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.CsrfTokenRepository; +import org.springframework.security.web.csrf.DefaultCsrfToken; import org.springframework.security.web.servlet.support.csrf.CsrfRequestDataValueProcessor import org.springframework.security.web.util.RequestMatcher @@ -113,7 +114,7 @@ class CsrfConfigTests extends AbstractHttpConfigTests { mockBean(CsrfTokenRepository,'repo') createAppContext() CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository) - CsrfToken token = new CsrfToken("X-CSRF-TOKEN","_csrf", "abc") + CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc") when(repo.loadToken(any(HttpServletRequest))).thenReturn(token) request.setParameter(token.parameterName,token.token) request.servletPath = "/some-url" @@ -147,7 +148,7 @@ class CsrfConfigTests extends AbstractHttpConfigTests { mockBean(CsrfTokenRepository,'repo') createAppContext() CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository) - CsrfToken token = new CsrfToken("X-CSRF-TOKEN","_csrf", "abc") + CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc") when(repo.loadToken(any(HttpServletRequest))).thenReturn(token) request.setParameter(token.parameterName,token.token) request.servletPath = "/some-url" @@ -200,7 +201,7 @@ class CsrfConfigTests extends AbstractHttpConfigTests { mockBean(CsrfTokenRepository,'repo') createAppContext() CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository) - CsrfToken token = new CsrfToken("X-CSRF-TOKEN","_csrf", "abc") + CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc") when(repo.loadToken(any(HttpServletRequest))).thenReturn(token) request.setParameter(token.parameterName,token.token) request.method = "POST" @@ -223,7 +224,7 @@ class CsrfConfigTests extends AbstractHttpConfigTests { mockBean(CsrfTokenRepository,'repo') createAppContext() CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository) - CsrfToken token = new CsrfToken("X-CSRF-TOKEN","_csrf", "abc") + CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc") when(repo.loadToken(any(HttpServletRequest))).thenReturn(token) request.setParameter(token.parameterName,token.token) request.method = "POST" @@ -244,7 +245,7 @@ class CsrfConfigTests extends AbstractHttpConfigTests { mockBean(CsrfTokenRepository,'repo') createAppContext() CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository) - CsrfToken token = new CsrfToken("X-CSRF-TOKEN","_csrf", "abc") + CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc") when(repo.loadToken(any(HttpServletRequest))).thenReturn(token) request.setParameter(token.parameterName,token.token) request.method = "POST" diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerServlet31Tests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerServlet31Tests.java index 6f834abd76..e090e914e7 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerServlet31Tests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerServlet31Tests.java @@ -40,7 +40,6 @@ import org.springframework.context.annotation.Configuration; import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; -import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; @@ -96,7 +95,9 @@ public class SessionManagementConfigurerServlet31Tests { request.setMethod("POST"); request.setParameter("username", "user"); request.setParameter("password", "password"); - CsrfToken token = new HttpSessionCsrfTokenRepository().generateAndSaveToken(request, response); + HttpSessionCsrfTokenRepository repository = new HttpSessionCsrfTokenRepository(); + CsrfToken token = repository.generateToken(request); + repository.saveToken(token, request, response); request.setParameter(token.getParameterName(),token.getToken()); when(ReflectionUtils.findMethod(HttpServletRequest.class, "changeSessionId")).thenReturn(method); diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java index 5851dc7c13..e6de1f4b36 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java @@ -70,11 +70,11 @@ public final class CsrfFilter extends OncePerRequestFilter { throws ServletException, IOException { CsrfToken csrfToken = tokenRepository.loadToken(request); if(csrfToken == null) { - csrfToken = tokenRepository.generateAndSaveToken(request, response); + CsrfToken generatedToken = tokenRepository.generateToken(request); + csrfToken = new SaveOnAccessCsrfToken(tokenRepository, request, response, generatedToken); } request.setAttribute(CsrfToken.class.getName(), csrfToken); request.setAttribute(csrfToken.getParameterName(), csrfToken); - response.addHeader(csrfToken.getHeaderName(), csrfToken.getToken()); if(!requireCsrfProtectionMatcher.matches(request)) { filterChain.doFilter(request, response); @@ -128,7 +128,86 @@ public final class CsrfFilter extends OncePerRequestFilter { this.accessDeniedHandler = accessDeniedHandler; } - private static class DefaultRequiresCsrfMatcher implements RequestMatcher { + @SuppressWarnings("serial") + private static final class SaveOnAccessCsrfToken implements CsrfToken { + private transient CsrfTokenRepository tokenRepository; + private transient HttpServletRequest request; + private transient HttpServletResponse response; + + private final CsrfToken delegate; + + public SaveOnAccessCsrfToken(CsrfTokenRepository tokenRepository, + HttpServletRequest request, HttpServletResponse response, + CsrfToken delegate) { + super(); + this.tokenRepository = tokenRepository; + this.request = request; + this.response = response; + this.delegate = delegate; + } + + public String getHeaderName() { + return delegate.getHeaderName(); + } + + public String getParameterName() { + return delegate.getParameterName(); + } + + public String getToken() { + saveTokenIfNecessary(); + return delegate.getToken(); + } + + @Override + public String toString() { + return "SaveOnAccessCsrfToken [delegate=" + delegate + "]"; + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + + ((delegate == null) ? 0 : delegate.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + SaveOnAccessCsrfToken other = (SaveOnAccessCsrfToken) obj; + if (delegate == null) { + if (other.delegate != null) + return false; + } else if (!delegate.equals(other.delegate)) + return false; + return true; + } + + private void saveTokenIfNecessary() { + if(this.tokenRepository == null) { + return; + } + + synchronized(this) { + if(tokenRepository != null) { + this.tokenRepository.saveToken(delegate, request, response); + this.tokenRepository = null; + this.request = null; + this.response = null; + } + } + } + + } + + private static final class DefaultRequiresCsrfMatcher implements RequestMatcher { private Pattern allowedMethods = Pattern.compile("^(GET|HEAD|TRACE|OPTIONS)$"); /* (non-Javadoc) diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfToken.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfToken.java index 46f680d5b7..1f98c405e8 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfToken.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfToken.java @@ -17,37 +17,16 @@ package org.springframework.security.web.csrf; import java.io.Serializable; -import org.springframework.util.Assert; - /** - * A CSRF token that is used to protect against CSRF attacks. + * Provides the information about an expected CSRF token. + * + * @see DefaultCsrfToken * * @author Rob Winch * @since 3.2 + * */ -@SuppressWarnings("serial") -public final class CsrfToken implements Serializable { - - private final String token; - - private final String parameterName; - - private final String headerName; - - /** - * Creates a new instance - * @param headerName the HTTP header name to use - * @param parameterName the HTTP parameter name to use - * @param token the value of the token (i.e. expected value of the HTTP parameter of parametername). - */ - public CsrfToken(String headerName, String parameterName, String token) { - Assert.hasLength(headerName, "headerName cannot be null or empty"); - Assert.hasLength(parameterName, "parameterName cannot be null or empty"); - Assert.hasLength(token, "token cannot be null or empty"); - this.headerName = headerName; - this.parameterName = parameterName; - this.token = token; - } +public interface CsrfToken extends Serializable { /** * Gets the HTTP header that the CSRF is populated on the response and can @@ -56,23 +35,18 @@ public final class CsrfToken implements Serializable { * @return the HTTP header that the CSRF is populated on the response and * can be placed on requests instead of the parameter */ - public String getHeaderName() { - return headerName; - } + String getHeaderName(); /** * Gets the HTTP parameter name that should contain the token. Cannot be null. * @return the HTTP parameter name that should contain the token. */ - public String getParameterName() { - return parameterName; - } + String getParameterName(); /** * Gets the token value. Cannot be null. * @return the token value */ - public String getToken() { - return token; - } -} + String getToken(); + +} \ No newline at end of file diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRepository.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRepository.java index 4b705fd27c..36a8109a7d 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRepository.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRepository.java @@ -33,17 +33,14 @@ import javax.servlet.http.HttpSession; public interface CsrfTokenRepository { /** - * Generates and saves the expected {@link CsrfToken} + * Generates a {@link CsrfToken} * * @param request * the {@link HttpServletRequest} to use - * @param response - * the {@link HttpServletResponse} to use - * @return the {@link CsrfToken} that was generated and saved. Cannot be + * @return the {@link CsrfToken} that was generated. Cannot be * null. */ - CsrfToken generateAndSaveToken(HttpServletRequest request, - HttpServletResponse response); + CsrfToken generateToken(HttpServletRequest request); /** * Saves the {@link CsrfToken} using the {@link HttpServletRequest} and diff --git a/web/src/main/java/org/springframework/security/web/csrf/DefaultCsrfToken.java b/web/src/main/java/org/springframework/security/web/csrf/DefaultCsrfToken.java new file mode 100644 index 0000000000..8c9dc20dbb --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/csrf/DefaultCsrfToken.java @@ -0,0 +1,70 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.web.csrf; + +import org.springframework.util.Assert; + +/** + * A CSRF token that is used to protect against CSRF attacks. + * + * @author Rob Winch + * @since 3.2 + */ +@SuppressWarnings("serial") +public final class DefaultCsrfToken implements CsrfToken { + + private final String token; + + private final String parameterName; + + private final String headerName; + + /** + * Creates a new instance + * @param headerName the HTTP header name to use + * @param parameterName the HTTP parameter name to use + * @param token the value of the token (i.e. expected value of the HTTP parameter of parametername). + */ + public DefaultCsrfToken(String headerName, String parameterName, String token) { + Assert.hasLength(headerName, "headerName cannot be null or empty"); + Assert.hasLength(parameterName, "parameterName cannot be null or empty"); + Assert.hasLength(token, "token cannot be null or empty"); + this.headerName = headerName; + this.parameterName = parameterName; + this.token = token; + } + + /* (non-Javadoc) + * @see org.springframework.security.web.csrf.CsrfToken#getHeaderName() + */ + public String getHeaderName() { + return headerName; + } + + /* (non-Javadoc) + * @see org.springframework.security.web.csrf.CsrfToken#getParameterName() + */ + public String getParameterName() { + return parameterName; + } + + /* (non-Javadoc) + * @see org.springframework.security.web.csrf.CsrfToken#getToken() + */ + public String getToken() { + return token; + } +} diff --git a/web/src/main/java/org/springframework/security/web/csrf/HttpSessionCsrfTokenRepository.java b/web/src/main/java/org/springframework/security/web/csrf/HttpSessionCsrfTokenRepository.java index 6eecd01b63..c9927fd91c 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/HttpSessionCsrfTokenRepository.java +++ b/web/src/main/java/org/springframework/security/web/csrf/HttpSessionCsrfTokenRepository.java @@ -63,14 +63,12 @@ public final class HttpSessionCsrfTokenRepository implements CsrfTokenRepository return (CsrfToken) request.getSession().getAttribute(sessionAttributeName); } - /* (non-Javadoc) - * @see org.springframework.security.web.csrf.CsrfTokenRepository#generateNewToken(javax.servlet.http.HttpServletRequest, javax.servlet.http.HttpServletResponse) + /* + * (non-Javadoc) + * @see org.springframework.security.web.csrf.CsrfTokenRepository#generateToken(javax.servlet.http.HttpServletRequest) */ - public CsrfToken generateAndSaveToken(HttpServletRequest request, - HttpServletResponse response) { - CsrfToken token = new CsrfToken(headerName, parameterName, createNewToken()); - saveToken(token, request, response); - return token; + public CsrfToken generateToken(HttpServletRequest request) { + return new DefaultCsrfToken(headerName, parameterName, createNewToken()); } /** diff --git a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java index fc1986608f..c9d32a68a7 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java @@ -18,6 +18,7 @@ package org.springframework.security.web.csrf; import static org.fest.assertions.Assertions.assertThat; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.when; @@ -27,8 +28,11 @@ import java.util.Arrays; import javax.servlet.FilterChain; import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import org.fest.assertions.GenericAssert; +import org.fest.assertions.ObjectAssert; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -59,12 +63,12 @@ public class CsrfFilterTests { private MockHttpServletResponse response; private CsrfToken token; - private CsrfFilter filter; @Before public void setup() { - token = new CsrfToken("headerName","paramName", "csrfTokenValue"); + token = new DefaultCsrfToken("headerName", "paramName", + "csrfTokenValue"); resetRequestResponse(); filter = new CsrfFilter(tokenRepository); filter.setRequireCsrfProtectionMatcher(requestMatcher); @@ -81,171 +85,221 @@ public class CsrfFilterTests { new CsrfFilter(null); } + // SEC-2276 @Test - public void doFilterAccessDeniedNoTokenPresent() throws ServletException, IOException { + public void doFilterDoesNotSaveCsrfTokenUntilAccessed() throws ServletException, + IOException { + when(requestMatcher.matches(request)).thenReturn(false); + when(tokenRepository.generateToken(request)).thenReturn(token); + + filter.doFilter(request, response, filterChain); + CsrfToken attrToken = (CsrfToken) request.getAttribute(token.getParameterName()); + + // no CsrfToken should have been saved yet + verify(tokenRepository,times(0)).saveToken(any(CsrfToken.class), any(HttpServletRequest.class), any(HttpServletResponse.class)); + verify(filterChain).doFilter(request, response); + + // access the token + attrToken.getToken(); + + // now the CsrfToken should have been saved + verify(tokenRepository).saveToken(eq(token), any(HttpServletRequest.class), any(HttpServletResponse.class)); + } + + @Test + public void doFilterAccessDeniedNoTokenPresent() throws ServletException, + IOException { when(requestMatcher.matches(request)).thenReturn(true); when(tokenRepository.loadToken(request)).thenReturn(token); filter.doFilter(request, response, filterChain); - assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken()); - assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); - assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); + assertThat(request.getAttribute(token.getParameterName())).isEqualTo( + token); + assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo( + token); - verify(deniedHandler).handle(eq(request), eq(response), any(InvalidCsrfTokenException.class)); + verify(deniedHandler).handle(eq(request), eq(response), + any(InvalidCsrfTokenException.class)); verifyZeroInteractions(filterChain); } @Test - public void doFilterAccessDeniedIncorrectTokenPresent() throws ServletException, IOException { + public void doFilterAccessDeniedIncorrectTokenPresent() + throws ServletException, IOException { when(requestMatcher.matches(request)).thenReturn(true); when(tokenRepository.loadToken(request)).thenReturn(token); - request.setParameter(token.getParameterName(), token.getToken()+ " INVALID"); + request.setParameter(token.getParameterName(), token.getToken() + + " INVALID"); filter.doFilter(request, response, filterChain); - assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken()); - assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); - assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); + assertThat(request.getAttribute(token.getParameterName())).isEqualTo( + token); + assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo( + token); - verify(deniedHandler).handle(eq(request), eq(response), any(InvalidCsrfTokenException.class)); + verify(deniedHandler).handle(eq(request), eq(response), + any(InvalidCsrfTokenException.class)); verifyZeroInteractions(filterChain); } @Test - public void doFilterAccessDeniedIncorrectTokenPresentHeader() throws ServletException, IOException { + public void doFilterAccessDeniedIncorrectTokenPresentHeader() + throws ServletException, IOException { when(requestMatcher.matches(request)).thenReturn(true); when(tokenRepository.loadToken(request)).thenReturn(token); - request.addHeader(token.getHeaderName(), token.getToken()+ " INVALID"); + request.addHeader(token.getHeaderName(), token.getToken() + " INVALID"); filter.doFilter(request, response, filterChain); - assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken()); - assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); - assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); + assertThat(request.getAttribute(token.getParameterName())).isEqualTo( + token); + assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo( + token); - verify(deniedHandler).handle(eq(request), eq(response), any(InvalidCsrfTokenException.class)); + verify(deniedHandler).handle(eq(request), eq(response), + any(InvalidCsrfTokenException.class)); verifyZeroInteractions(filterChain); } @Test - public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter() throws ServletException, IOException { + public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter() + throws ServletException, IOException { when(requestMatcher.matches(request)).thenReturn(true); when(tokenRepository.loadToken(request)).thenReturn(token); request.setParameter(token.getParameterName(), token.getToken()); - request.addHeader(token.getHeaderName(), token.getToken()+ " INVALID"); + request.addHeader(token.getHeaderName(), token.getToken() + " INVALID"); filter.doFilter(request, response, filterChain); - assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken()); - assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); - assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); + assertThat(request.getAttribute(token.getParameterName())).isEqualTo( + token); + assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo( + token); - verify(deniedHandler).handle(eq(request), eq(response), any(InvalidCsrfTokenException.class)); + verify(deniedHandler).handle(eq(request), eq(response), + any(InvalidCsrfTokenException.class)); verifyZeroInteractions(filterChain); } @Test - public void doFilterNotCsrfRequestExistingToken() throws ServletException, IOException { + public void doFilterNotCsrfRequestExistingToken() throws ServletException, + IOException { when(requestMatcher.matches(request)).thenReturn(false); when(tokenRepository.loadToken(request)).thenReturn(token); filter.doFilter(request, response, filterChain); - assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken()); - assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); - assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); + assertThat(request.getAttribute(token.getParameterName())).isEqualTo( + token); + assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo( + token); verify(filterChain).doFilter(request, response); verifyZeroInteractions(deniedHandler); } @Test - public void doFilterNotCsrfRequestGenerateToken() throws ServletException, IOException { + public void doFilterNotCsrfRequestGenerateToken() throws ServletException, + IOException { when(requestMatcher.matches(request)).thenReturn(false); - when(tokenRepository.generateAndSaveToken(request, response)).thenReturn(token); + when(tokenRepository.generateToken(request)) + .thenReturn(token); filter.doFilter(request, response, filterChain); - assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken()); - assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); - assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); + assertToken(request.getAttribute(token.getParameterName())).isEqualTo( + token); + assertToken(request.getAttribute(CsrfToken.class.getName())).isEqualTo( + token); verify(filterChain).doFilter(request, response); verifyZeroInteractions(deniedHandler); } @Test - public void doFilterIsCsrfRequestExistingTokenHeader() throws ServletException, IOException { + public void doFilterIsCsrfRequestExistingTokenHeader() + throws ServletException, IOException { when(requestMatcher.matches(request)).thenReturn(true); when(tokenRepository.loadToken(request)).thenReturn(token); request.addHeader(token.getHeaderName(), token.getToken()); filter.doFilter(request, response, filterChain); - assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken()); - assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); - assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); + assertThat(request.getAttribute(token.getParameterName())).isEqualTo( + token); + assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo( + token); verify(filterChain).doFilter(request, response); verifyZeroInteractions(deniedHandler); } @Test - public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam() throws ServletException, IOException { + public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam() + throws ServletException, IOException { when(requestMatcher.matches(request)).thenReturn(true); when(tokenRepository.loadToken(request)).thenReturn(token); - request.setParameter(token.getParameterName(), token.getToken()+ " INVALID"); + request.setParameter(token.getParameterName(), token.getToken() + + " INVALID"); request.addHeader(token.getHeaderName(), token.getToken()); filter.doFilter(request, response, filterChain); - assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken()); - assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); - assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); + assertThat(request.getAttribute(token.getParameterName())).isEqualTo( + token); + assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo( + token); verify(filterChain).doFilter(request, response); verifyZeroInteractions(deniedHandler); } @Test - public void doFilterIsCsrfRequestExistingToken() throws ServletException, IOException { + public void doFilterIsCsrfRequestExistingToken() throws ServletException, + IOException { when(requestMatcher.matches(request)).thenReturn(true); when(tokenRepository.loadToken(request)).thenReturn(token); request.setParameter(token.getParameterName(), token.getToken()); filter.doFilter(request, response, filterChain); - assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken()); - assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); - assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); + assertThat(request.getAttribute(token.getParameterName())).isEqualTo( + token); + assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo( + token); verify(filterChain).doFilter(request, response); verifyZeroInteractions(deniedHandler); } @Test - public void doFilterIsCsrfRequestGenerateToken() throws ServletException, IOException { + public void doFilterIsCsrfRequestGenerateToken() throws ServletException, + IOException { when(requestMatcher.matches(request)).thenReturn(true); - when(tokenRepository.generateAndSaveToken(request, response)).thenReturn(token); + when(tokenRepository.generateToken(request)) + .thenReturn(token); request.setParameter(token.getParameterName(), token.getToken()); filter.doFilter(request, response, filterChain); - assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken()); - assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); - assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); + assertToken(request.getAttribute(token.getParameterName())).isEqualTo( + token); + assertToken(request.getAttribute(CsrfToken.class.getName())).isEqualTo( + token); verify(filterChain).doFilter(request, response); verifyZeroInteractions(deniedHandler); } @Test - public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods() throws ServletException, IOException { + public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods() + throws ServletException, IOException { filter = new CsrfFilter(tokenRepository); filter.setAccessDeniedHandler(deniedHandler); - for(String method : Arrays.asList("GET","TRACE", "OPTIONS", "HEAD")) { + for (String method : Arrays.asList("GET", "TRACE", "OPTIONS", "HEAD")) { resetRequestResponse(); when(tokenRepository.loadToken(request)).thenReturn(token); request.setMethod(method); @@ -258,24 +312,28 @@ public class CsrfFilterTests { } @Test - public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods() throws ServletException, IOException { + public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods() + throws ServletException, IOException { filter = new CsrfFilter(tokenRepository); filter.setAccessDeniedHandler(deniedHandler); - for(String method : Arrays.asList("POST","PUT", "PATCH", "DELETE", "INVALID")) { + for (String method : Arrays.asList("POST", "PUT", "PATCH", "DELETE", + "INVALID")) { resetRequestResponse(); when(tokenRepository.loadToken(request)).thenReturn(token); request.setMethod(method); filter.doFilter(request, response, filterChain); - verify(deniedHandler).handle(eq(request), eq(response), any(InvalidCsrfTokenException.class)); + verify(deniedHandler).handle(eq(request), eq(response), + any(InvalidCsrfTokenException.class)); verifyZeroInteractions(filterChain); } } @Test - public void doFilterDefaultAccessDenied() throws ServletException, IOException { + public void doFilterDefaultAccessDenied() throws ServletException, + IOException { filter = new CsrfFilter(tokenRepository); filter.setRequireCsrfProtectionMatcher(requestMatcher); when(requestMatcher.matches(request)).thenReturn(true); @@ -283,11 +341,13 @@ public class CsrfFilterTests { filter.doFilter(request, response, filterChain); - assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken()); - assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); - assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); + assertThat(request.getAttribute(token.getParameterName())).isEqualTo( + token); + assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo( + token); - assertThat(response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); + assertThat(response.getStatus()).isEqualTo( + HttpServletResponse.SC_FORBIDDEN); verifyZeroInteractions(filterChain); } @@ -300,4 +360,29 @@ public class CsrfFilterTests { public void setAccessDeniedHandlerNull() { filter.setAccessDeniedHandler(null); } + + private static final CsrfTokenAssert assertToken(Object token) { + return new CsrfTokenAssert((CsrfToken)token); + } + + private static class CsrfTokenAssert extends + GenericAssert { + + /** + * Creates a new {@link ObjectAssert}. + * + * @param actual + * the target to verify. + */ + protected CsrfTokenAssert(CsrfToken actual) { + super(CsrfTokenAssert.class, actual); + } + + public CsrfTokenAssert isEqualTo(CsrfToken expected) { + assertThat(actual.getHeaderName()).isEqualTo(expected.getHeaderName()); + assertThat(actual.getParameterName()).isEqualTo(expected.getParameterName()); + assertThat(actual.getToken()).isEqualTo(expected.getToken()); + return this; + } + } } diff --git a/web/src/test/java/org/springframework/security/web/csrf/CsrfTokenTests.java b/web/src/test/java/org/springframework/security/web/csrf/DefaultCsrfTokenTests.java similarity index 79% rename from web/src/test/java/org/springframework/security/web/csrf/CsrfTokenTests.java rename to web/src/test/java/org/springframework/security/web/csrf/DefaultCsrfTokenTests.java index f3997b5f28..531f1f592a 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/CsrfTokenTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/DefaultCsrfTokenTests.java @@ -21,38 +21,38 @@ import org.junit.Test; * @author Rob Winch * */ -public class CsrfTokenTests { +public class DefaultCsrfTokenTests { private final String headerName = "headerName"; private final String parameterName = "parameterName"; private final String tokenValue = "tokenValue"; @Test(expected = IllegalArgumentException.class) public void constructorNullHeaderName() { - new CsrfToken(null,parameterName, tokenValue); + new DefaultCsrfToken(null,parameterName, tokenValue); } @Test(expected = IllegalArgumentException.class) public void constructorEmptyHeaderName() { - new CsrfToken("",parameterName, tokenValue); + new DefaultCsrfToken("",parameterName, tokenValue); } @Test(expected = IllegalArgumentException.class) public void constructorNullParameterName() { - new CsrfToken(headerName,null, tokenValue); + new DefaultCsrfToken(headerName,null, tokenValue); } @Test(expected = IllegalArgumentException.class) public void constructorEmptyParameterName() { - new CsrfToken(headerName,"", tokenValue); + new DefaultCsrfToken(headerName,"", tokenValue); } @Test(expected = IllegalArgumentException.class) public void constructorNullTokenValue() { - new CsrfToken(headerName,parameterName, null); + new DefaultCsrfToken(headerName,parameterName, null); } @Test(expected = IllegalArgumentException.class) public void constructorEmptyTokenValue() { - new CsrfToken(headerName,parameterName, ""); + new DefaultCsrfToken(headerName,parameterName, ""); } } diff --git a/web/src/test/java/org/springframework/security/web/csrf/HttpSessionCsrfTokenRepositoryTests.java b/web/src/test/java/org/springframework/security/web/csrf/HttpSessionCsrfTokenRepositoryTests.java index d450452fe3..2820f90d83 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/HttpSessionCsrfTokenRepositoryTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/HttpSessionCsrfTokenRepositoryTests.java @@ -42,23 +42,23 @@ public class HttpSessionCsrfTokenRepositoryTests { } @Test - public void generateAndSaveToken() { - token = repo.generateAndSaveToken(request, response); + public void generateToken() { + token = repo.generateToken(request); assertThat(token.getParameterName()).isEqualTo("_csrf"); assertThat(token.getToken()).isNotEmpty(); CsrfToken loadedToken = repo.loadToken(request); - assertThat(loadedToken).isEqualTo(token); + assertThat(loadedToken).isNull(); } @Test - public void generateAndSaveTokenCustomParameter() { + public void generateCustomParameter() { String paramName = "_csrf"; repo.setParameterName(paramName); - token = repo.generateAndSaveToken(request, response); + token = repo.generateToken(request); assertThat(token.getParameterName()).isEqualTo(paramName); assertThat(token.getToken()).isNotEmpty(); @@ -71,7 +71,7 @@ public class HttpSessionCsrfTokenRepositoryTests { @Test public void saveToken() { - CsrfToken tokenToSave = new CsrfToken("123", "abc", "def"); + CsrfToken tokenToSave = new DefaultCsrfToken("123", "abc", "def"); repo.saveToken(tokenToSave, request, response); String attrName = request.getSession().getAttributeNames() @@ -84,7 +84,7 @@ public class HttpSessionCsrfTokenRepositoryTests { @Test public void saveTokenCustomSessionAttribute() { - CsrfToken tokenToSave = new CsrfToken("123", "abc", "def"); + CsrfToken tokenToSave = new DefaultCsrfToken("123", "abc", "def"); String sessionAttributeName = "custom"; repo.setSessionAttributeName(sessionAttributeName); repo.saveToken(tokenToSave, request, response); diff --git a/web/src/test/java/org/springframework/security/web/servlet/support/csrf/CsrfRequestDataValueProcessorTests.java b/web/src/test/java/org/springframework/security/web/servlet/support/csrf/CsrfRequestDataValueProcessorTests.java index 98ced2d165..525c129919 100644 --- a/web/src/test/java/org/springframework/security/web/servlet/support/csrf/CsrfRequestDataValueProcessorTests.java +++ b/web/src/test/java/org/springframework/security/web/servlet/support/csrf/CsrfRequestDataValueProcessorTests.java @@ -25,6 +25,7 @@ import org.junit.Test; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.web.csrf.CsrfToken; +import org.springframework.security.web.csrf.DefaultCsrfToken; /** * @author Rob Winch @@ -51,7 +52,7 @@ public class CsrfRequestDataValueProcessorTests { @Test public void getExtraHiddenFieldsHasCsrfToken() { - CsrfToken token = new CsrfToken("1", "a", "b"); + CsrfToken token = new DefaultCsrfToken("1", "a", "b"); request.setAttribute(CsrfToken.class.getName(), token); Map expected = new HashMap(); expected.put(token.getParameterName(),token.getToken());