diff --git a/config/src/test/groovy/org/springframework/security/config/annotation/BaseSpringSpec.groovy b/config/src/test/groovy/org/springframework/security/config/annotation/BaseSpringSpec.groovy index a29320da2b..489501e0e9 100644 --- a/config/src/test/groovy/org/springframework/security/config/annotation/BaseSpringSpec.groovy +++ b/config/src/test/groovy/org/springframework/security/config/annotation/BaseSpringSpec.groovy @@ -36,6 +36,7 @@ import org.springframework.security.web.access.intercept.FilterSecurityIntercept import org.springframework.security.web.context.HttpRequestResponseHolder import org.springframework.security.web.context.HttpSessionSecurityContextRepository import org.springframework.security.web.csrf.CsrfToken +import org.springframework.security.web.csrf.DefaultCsrfToken; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository import spock.lang.AutoCleanup @@ -69,7 +70,7 @@ abstract class BaseSpringSpec extends Specification { } def setupCsrf(csrfTokenValue="BaseSpringSpec_CSRFTOKEN") { - csrfToken = new CsrfToken("X-CSRF-TOKEN","_csrf",csrfTokenValue) + csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf",csrfTokenValue) new HttpSessionCsrfTokenRepository().saveToken(csrfToken, request,response) request.setParameter(csrfToken.parameterName, csrfToken.token) } diff --git a/config/src/test/groovy/org/springframework/security/config/annotation/web/WebSecurityConfigurerAdapterTests.groovy b/config/src/test/groovy/org/springframework/security/config/annotation/web/WebSecurityConfigurerAdapterTests.groovy index 32a6e40f53..6e15ea0cc9 100644 --- a/config/src/test/groovy/org/springframework/security/config/annotation/web/WebSecurityConfigurerAdapterTests.groovy +++ b/config/src/test/groovy/org/springframework/security/config/annotation/web/WebSecurityConfigurerAdapterTests.groovy @@ -79,8 +79,7 @@ class WebSecurityConfigurerAdapterTests extends BaseSpringSpec { 'Strict-Transport-Security': 'max-age=31536000 ; includeSubDomains', 'Cache-Control': 'no-cache,no-store,max-age=0,must-revalidate', 'Pragma':'no-cache', - 'X-XSS-Protection' : '1; mode=block', - 'X-CSRF-TOKEN' : csrfToken.token] + 'X-XSS-Protection' : '1; mode=block'] } @EnableWebSecurity diff --git a/config/src/test/groovy/org/springframework/security/config/annotation/web/configurers/NamespaceHttpHeadersTests.groovy b/config/src/test/groovy/org/springframework/security/config/annotation/web/configurers/NamespaceHttpHeadersTests.groovy index e5d4a37258..66595012b8 100644 --- a/config/src/test/groovy/org/springframework/security/config/annotation/web/configurers/NamespaceHttpHeadersTests.groovy +++ b/config/src/test/groovy/org/springframework/security/config/annotation/web/configurers/NamespaceHttpHeadersTests.groovy @@ -49,8 +49,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec { 'Strict-Transport-Security': 'max-age=31536000 ; includeSubDomains', 'Cache-Control': 'no-cache,no-store,max-age=0,must-revalidate', 'Pragma':'no-cache', - 'X-XSS-Protection' : '1; mode=block', - 'X-CSRF-TOKEN' : csrfToken.token] + 'X-XSS-Protection' : '1; mode=block'] } @Configuration @@ -70,8 +69,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec { springSecurityFilterChain.doFilter(request,response,chain) then: responseHeaders == ['Cache-Control': 'no-cache,no-store,max-age=0,must-revalidate', - 'Pragma':'no-cache', - 'X-CSRF-TOKEN' : csrfToken.token] + 'Pragma':'no-cache'] } @Configuration @@ -91,8 +89,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec { when: springSecurityFilterChain.doFilter(request,response,chain) then: - responseHeaders == ['Strict-Transport-Security': 'max-age=31536000 ; includeSubDomains', - 'X-CSRF-TOKEN' : csrfToken.token] + responseHeaders == ['Strict-Transport-Security': 'max-age=31536000 ; includeSubDomains'] } @Configuration @@ -111,8 +108,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec { when: springSecurityFilterChain.doFilter(request,response,chain) then: - responseHeaders == ['Strict-Transport-Security': 'max-age=15768000', - 'X-CSRF-TOKEN' : csrfToken.token] + responseHeaders == ['Strict-Transport-Security': 'max-age=15768000'] } @Configuration @@ -133,8 +129,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec { when: springSecurityFilterChain.doFilter(request,response,chain) then: - responseHeaders == ['X-Frame-Options': 'SAMEORIGIN', - 'X-CSRF-TOKEN' : csrfToken.token] + responseHeaders == ['X-Frame-Options': 'SAMEORIGIN'] } @Configuration @@ -156,8 +151,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec { when: springSecurityFilterChain.doFilter(request,response,chain) then: - responseHeaders == ['X-Frame-Options': 'ALLOW-FROM https://example.com', - 'X-CSRF-TOKEN' : csrfToken.token] + responseHeaders == ['X-Frame-Options': 'ALLOW-FROM https://example.com'] } @@ -178,8 +172,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec { when: springSecurityFilterChain.doFilter(request,response,chain) then: - responseHeaders == ['X-XSS-Protection': '1; mode=block', - 'X-CSRF-TOKEN' : csrfToken.token] + responseHeaders == ['X-XSS-Protection': '1; mode=block'] } @Configuration @@ -199,8 +192,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec { when: springSecurityFilterChain.doFilter(request,response,chain) then: - responseHeaders == ['X-XSS-Protection': '1', - 'X-CSRF-TOKEN' : csrfToken.token] + responseHeaders == ['X-XSS-Protection': '1'] } @Configuration @@ -220,8 +212,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec { when: springSecurityFilterChain.doFilter(request,response,chain) then: - responseHeaders == ['X-Content-Type-Options': 'nosniff', - 'X-CSRF-TOKEN' : csrfToken.token] + responseHeaders == ['X-Content-Type-Options': 'nosniff'] } @Configuration @@ -243,8 +234,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec { when: springSecurityFilterChain.doFilter(request,response,chain) then: - responseHeaders == ['customHeaderName': 'customHeaderValue', - 'X-CSRF-TOKEN' : csrfToken.token] + responseHeaders == ['customHeaderName': 'customHeaderValue'] } @Configuration diff --git a/config/src/test/groovy/org/springframework/security/config/http/CsrfConfigTests.groovy b/config/src/test/groovy/org/springframework/security/config/http/CsrfConfigTests.groovy index 49d5c6c80c..72693e700a 100644 --- a/config/src/test/groovy/org/springframework/security/config/http/CsrfConfigTests.groovy +++ b/config/src/test/groovy/org/springframework/security/config/http/CsrfConfigTests.groovy @@ -29,6 +29,7 @@ import org.springframework.security.web.access.AccessDeniedHandler; import org.springframework.security.web.csrf.CsrfFilter import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.CsrfTokenRepository; +import org.springframework.security.web.csrf.DefaultCsrfToken; import org.springframework.security.web.servlet.support.csrf.CsrfRequestDataValueProcessor import org.springframework.security.web.util.RequestMatcher @@ -113,7 +114,7 @@ class CsrfConfigTests extends AbstractHttpConfigTests { mockBean(CsrfTokenRepository,'repo') createAppContext() CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository) - CsrfToken token = new CsrfToken("X-CSRF-TOKEN","_csrf", "abc") + CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc") when(repo.loadToken(any(HttpServletRequest))).thenReturn(token) request.setParameter(token.parameterName,token.token) request.servletPath = "/some-url" @@ -147,7 +148,7 @@ class CsrfConfigTests extends AbstractHttpConfigTests { mockBean(CsrfTokenRepository,'repo') createAppContext() CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository) - CsrfToken token = new CsrfToken("X-CSRF-TOKEN","_csrf", "abc") + CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc") when(repo.loadToken(any(HttpServletRequest))).thenReturn(token) request.setParameter(token.parameterName,token.token) request.servletPath = "/some-url" @@ -200,7 +201,7 @@ class CsrfConfigTests extends AbstractHttpConfigTests { mockBean(CsrfTokenRepository,'repo') createAppContext() CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository) - CsrfToken token = new CsrfToken("X-CSRF-TOKEN","_csrf", "abc") + CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc") when(repo.loadToken(any(HttpServletRequest))).thenReturn(token) request.setParameter(token.parameterName,token.token) request.method = "POST" @@ -223,7 +224,7 @@ class CsrfConfigTests extends AbstractHttpConfigTests { mockBean(CsrfTokenRepository,'repo') createAppContext() CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository) - CsrfToken token = new CsrfToken("X-CSRF-TOKEN","_csrf", "abc") + CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc") when(repo.loadToken(any(HttpServletRequest))).thenReturn(token) request.setParameter(token.parameterName,token.token) request.method = "POST" @@ -244,7 +245,7 @@ class CsrfConfigTests extends AbstractHttpConfigTests { mockBean(CsrfTokenRepository,'repo') createAppContext() CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository) - CsrfToken token = new CsrfToken("X-CSRF-TOKEN","_csrf", "abc") + CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc") when(repo.loadToken(any(HttpServletRequest))).thenReturn(token) request.setParameter(token.parameterName,token.token) request.method = "POST" diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerServlet31Tests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerServlet31Tests.java index 6f834abd76..e090e914e7 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerServlet31Tests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerServlet31Tests.java @@ -40,7 +40,6 @@ import org.springframework.context.annotation.Configuration; import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; -import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; @@ -96,7 +95,9 @@ public class SessionManagementConfigurerServlet31Tests { request.setMethod("POST"); request.setParameter("username", "user"); request.setParameter("password", "password"); - CsrfToken token = new HttpSessionCsrfTokenRepository().generateAndSaveToken(request, response); + HttpSessionCsrfTokenRepository repository = new HttpSessionCsrfTokenRepository(); + CsrfToken token = repository.generateToken(request); + repository.saveToken(token, request, response); request.setParameter(token.getParameterName(),token.getToken()); when(ReflectionUtils.findMethod(HttpServletRequest.class, "changeSessionId")).thenReturn(method); diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java index 5851dc7c13..e6de1f4b36 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java @@ -70,11 +70,11 @@ public final class CsrfFilter extends OncePerRequestFilter { throws ServletException, IOException { CsrfToken csrfToken = tokenRepository.loadToken(request); if(csrfToken == null) { - csrfToken = tokenRepository.generateAndSaveToken(request, response); + CsrfToken generatedToken = tokenRepository.generateToken(request); + csrfToken = new SaveOnAccessCsrfToken(tokenRepository, request, response, generatedToken); } request.setAttribute(CsrfToken.class.getName(), csrfToken); request.setAttribute(csrfToken.getParameterName(), csrfToken); - response.addHeader(csrfToken.getHeaderName(), csrfToken.getToken()); if(!requireCsrfProtectionMatcher.matches(request)) { filterChain.doFilter(request, response); @@ -128,7 +128,86 @@ public final class CsrfFilter extends OncePerRequestFilter { this.accessDeniedHandler = accessDeniedHandler; } - private static class DefaultRequiresCsrfMatcher implements RequestMatcher { + @SuppressWarnings("serial") + private static final class SaveOnAccessCsrfToken implements CsrfToken { + private transient CsrfTokenRepository tokenRepository; + private transient HttpServletRequest request; + private transient HttpServletResponse response; + + private final CsrfToken delegate; + + public SaveOnAccessCsrfToken(CsrfTokenRepository tokenRepository, + HttpServletRequest request, HttpServletResponse response, + CsrfToken delegate) { + super(); + this.tokenRepository = tokenRepository; + this.request = request; + this.response = response; + this.delegate = delegate; + } + + public String getHeaderName() { + return delegate.getHeaderName(); + } + + public String getParameterName() { + return delegate.getParameterName(); + } + + public String getToken() { + saveTokenIfNecessary(); + return delegate.getToken(); + } + + @Override + public String toString() { + return "SaveOnAccessCsrfToken [delegate=" + delegate + "]"; + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + + ((delegate == null) ? 0 : delegate.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + SaveOnAccessCsrfToken other = (SaveOnAccessCsrfToken) obj; + if (delegate == null) { + if (other.delegate != null) + return false; + } else if (!delegate.equals(other.delegate)) + return false; + return true; + } + + private void saveTokenIfNecessary() { + if(this.tokenRepository == null) { + return; + } + + synchronized(this) { + if(tokenRepository != null) { + this.tokenRepository.saveToken(delegate, request, response); + this.tokenRepository = null; + this.request = null; + this.response = null; + } + } + } + + } + + private static final class DefaultRequiresCsrfMatcher implements RequestMatcher { private Pattern allowedMethods = Pattern.compile("^(GET|HEAD|TRACE|OPTIONS)$"); /* (non-Javadoc) diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfToken.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfToken.java index 46f680d5b7..1f98c405e8 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfToken.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfToken.java @@ -17,37 +17,16 @@ package org.springframework.security.web.csrf; import java.io.Serializable; -import org.springframework.util.Assert; - /** - * A CSRF token that is used to protect against CSRF attacks. + * Provides the information about an expected CSRF token. + * + * @see DefaultCsrfToken * * @author Rob Winch * @since 3.2 + * */ -@SuppressWarnings("serial") -public final class CsrfToken implements Serializable { - - private final String token; - - private final String parameterName; - - private final String headerName; - - /** - * Creates a new instance - * @param headerName the HTTP header name to use - * @param parameterName the HTTP parameter name to use - * @param token the value of the token (i.e. expected value of the HTTP parameter of parametername). - */ - public CsrfToken(String headerName, String parameterName, String token) { - Assert.hasLength(headerName, "headerName cannot be null or empty"); - Assert.hasLength(parameterName, "parameterName cannot be null or empty"); - Assert.hasLength(token, "token cannot be null or empty"); - this.headerName = headerName; - this.parameterName = parameterName; - this.token = token; - } +public interface CsrfToken extends Serializable { /** * Gets the HTTP header that the CSRF is populated on the response and can @@ -56,23 +35,18 @@ public final class CsrfToken implements Serializable { * @return the HTTP header that the CSRF is populated on the response and * can be placed on requests instead of the parameter */ - public String getHeaderName() { - return headerName; - } + String getHeaderName(); /** * Gets the HTTP parameter name that should contain the token. Cannot be null. * @return the HTTP parameter name that should contain the token. */ - public String getParameterName() { - return parameterName; - } + String getParameterName(); /** * Gets the token value. Cannot be null. * @return the token value */ - public String getToken() { - return token; - } -} + String getToken(); + +} \ No newline at end of file diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRepository.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRepository.java index 4b705fd27c..36a8109a7d 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRepository.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRepository.java @@ -33,17 +33,14 @@ import javax.servlet.http.HttpSession; public interface CsrfTokenRepository { /** - * Generates and saves the expected {@link CsrfToken} + * Generates a {@link CsrfToken} * * @param request * the {@link HttpServletRequest} to use - * @param response - * the {@link HttpServletResponse} to use - * @return the {@link CsrfToken} that was generated and saved. Cannot be + * @return the {@link CsrfToken} that was generated. Cannot be * null. */ - CsrfToken generateAndSaveToken(HttpServletRequest request, - HttpServletResponse response); + CsrfToken generateToken(HttpServletRequest request); /** * Saves the {@link CsrfToken} using the {@link HttpServletRequest} and diff --git a/web/src/main/java/org/springframework/security/web/csrf/DefaultCsrfToken.java b/web/src/main/java/org/springframework/security/web/csrf/DefaultCsrfToken.java new file mode 100644 index 0000000000..8c9dc20dbb --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/csrf/DefaultCsrfToken.java @@ -0,0 +1,70 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.web.csrf; + +import org.springframework.util.Assert; + +/** + * A CSRF token that is used to protect against CSRF attacks. + * + * @author Rob Winch + * @since 3.2 + */ +@SuppressWarnings("serial") +public final class DefaultCsrfToken implements CsrfToken { + + private final String token; + + private final String parameterName; + + private final String headerName; + + /** + * Creates a new instance + * @param headerName the HTTP header name to use + * @param parameterName the HTTP parameter name to use + * @param token the value of the token (i.e. expected value of the HTTP parameter of parametername). + */ + public DefaultCsrfToken(String headerName, String parameterName, String token) { + Assert.hasLength(headerName, "headerName cannot be null or empty"); + Assert.hasLength(parameterName, "parameterName cannot be null or empty"); + Assert.hasLength(token, "token cannot be null or empty"); + this.headerName = headerName; + this.parameterName = parameterName; + this.token = token; + } + + /* (non-Javadoc) + * @see org.springframework.security.web.csrf.CsrfToken#getHeaderName() + */ + public String getHeaderName() { + return headerName; + } + + /* (non-Javadoc) + * @see org.springframework.security.web.csrf.CsrfToken#getParameterName() + */ + public String getParameterName() { + return parameterName; + } + + /* (non-Javadoc) + * @see org.springframework.security.web.csrf.CsrfToken#getToken() + */ + public String getToken() { + return token; + } +} diff --git a/web/src/main/java/org/springframework/security/web/csrf/HttpSessionCsrfTokenRepository.java b/web/src/main/java/org/springframework/security/web/csrf/HttpSessionCsrfTokenRepository.java index 6eecd01b63..c9927fd91c 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/HttpSessionCsrfTokenRepository.java +++ b/web/src/main/java/org/springframework/security/web/csrf/HttpSessionCsrfTokenRepository.java @@ -63,14 +63,12 @@ public final class HttpSessionCsrfTokenRepository implements CsrfTokenRepository return (CsrfToken) request.getSession().getAttribute(sessionAttributeName); } - /* (non-Javadoc) - * @see org.springframework.security.web.csrf.CsrfTokenRepository#generateNewToken(javax.servlet.http.HttpServletRequest, javax.servlet.http.HttpServletResponse) + /* + * (non-Javadoc) + * @see org.springframework.security.web.csrf.CsrfTokenRepository#generateToken(javax.servlet.http.HttpServletRequest) */ - public CsrfToken generateAndSaveToken(HttpServletRequest request, - HttpServletResponse response) { - CsrfToken token = new CsrfToken(headerName, parameterName, createNewToken()); - saveToken(token, request, response); - return token; + public CsrfToken generateToken(HttpServletRequest request) { + return new DefaultCsrfToken(headerName, parameterName, createNewToken()); } /** diff --git a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java index fc1986608f..c9d32a68a7 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java @@ -18,6 +18,7 @@ package org.springframework.security.web.csrf; import static org.fest.assertions.Assertions.assertThat; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.when; @@ -27,8 +28,11 @@ import java.util.Arrays; import javax.servlet.FilterChain; import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import org.fest.assertions.GenericAssert; +import org.fest.assertions.ObjectAssert; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -59,12 +63,12 @@ public class CsrfFilterTests { private MockHttpServletResponse response; private CsrfToken token; - private CsrfFilter filter; @Before public void setup() { - token = new CsrfToken("headerName","paramName", "csrfTokenValue"); + token = new DefaultCsrfToken("headerName", "paramName", + "csrfTokenValue"); resetRequestResponse(); filter = new CsrfFilter(tokenRepository); filter.setRequireCsrfProtectionMatcher(requestMatcher); @@ -81,171 +85,221 @@ public class CsrfFilterTests { new CsrfFilter(null); } + // SEC-2276 @Test - public void doFilterAccessDeniedNoTokenPresent() throws ServletException, IOException { + public void doFilterDoesNotSaveCsrfTokenUntilAccessed() throws ServletException, + IOException { + when(requestMatcher.matches(request)).thenReturn(false); + when(tokenRepository.generateToken(request)).thenReturn(token); + + filter.doFilter(request, response, filterChain); + CsrfToken attrToken = (CsrfToken) request.getAttribute(token.getParameterName()); + + // no CsrfToken should have been saved yet + verify(tokenRepository,times(0)).saveToken(any(CsrfToken.class), any(HttpServletRequest.class), any(HttpServletResponse.class)); + verify(filterChain).doFilter(request, response); + + // access the token + attrToken.getToken(); + + // now the CsrfToken should have been saved + verify(tokenRepository).saveToken(eq(token), any(HttpServletRequest.class), any(HttpServletResponse.class)); + } + + @Test + public void doFilterAccessDeniedNoTokenPresent() throws ServletException, + IOException { when(requestMatcher.matches(request)).thenReturn(true); when(tokenRepository.loadToken(request)).thenReturn(token); filter.doFilter(request, response, filterChain); - assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken()); - assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); - assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); + assertThat(request.getAttribute(token.getParameterName())).isEqualTo( + token); + assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo( + token); - verify(deniedHandler).handle(eq(request), eq(response), any(InvalidCsrfTokenException.class)); + verify(deniedHandler).handle(eq(request), eq(response), + any(InvalidCsrfTokenException.class)); verifyZeroInteractions(filterChain); } @Test - public void doFilterAccessDeniedIncorrectTokenPresent() throws ServletException, IOException { + public void doFilterAccessDeniedIncorrectTokenPresent() + throws ServletException, IOException { when(requestMatcher.matches(request)).thenReturn(true); when(tokenRepository.loadToken(request)).thenReturn(token); - request.setParameter(token.getParameterName(), token.getToken()+ " INVALID"); + request.setParameter(token.getParameterName(), token.getToken() + + " INVALID"); filter.doFilter(request, response, filterChain); - assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken()); - assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); - assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); + assertThat(request.getAttribute(token.getParameterName())).isEqualTo( + token); + assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo( + token); - verify(deniedHandler).handle(eq(request), eq(response), any(InvalidCsrfTokenException.class)); + verify(deniedHandler).handle(eq(request), eq(response), + any(InvalidCsrfTokenException.class)); verifyZeroInteractions(filterChain); } @Test - public void doFilterAccessDeniedIncorrectTokenPresentHeader() throws ServletException, IOException { + public void doFilterAccessDeniedIncorrectTokenPresentHeader() + throws ServletException, IOException { when(requestMatcher.matches(request)).thenReturn(true); when(tokenRepository.loadToken(request)).thenReturn(token); - request.addHeader(token.getHeaderName(), token.getToken()+ " INVALID"); + request.addHeader(token.getHeaderName(), token.getToken() + " INVALID"); filter.doFilter(request, response, filterChain); - assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken()); - assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); - assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); + assertThat(request.getAttribute(token.getParameterName())).isEqualTo( + token); + assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo( + token); - verify(deniedHandler).handle(eq(request), eq(response), any(InvalidCsrfTokenException.class)); + verify(deniedHandler).handle(eq(request), eq(response), + any(InvalidCsrfTokenException.class)); verifyZeroInteractions(filterChain); } @Test - public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter() throws ServletException, IOException { + public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter() + throws ServletException, IOException { when(requestMatcher.matches(request)).thenReturn(true); when(tokenRepository.loadToken(request)).thenReturn(token); request.setParameter(token.getParameterName(), token.getToken()); - request.addHeader(token.getHeaderName(), token.getToken()+ " INVALID"); + request.addHeader(token.getHeaderName(), token.getToken() + " INVALID"); filter.doFilter(request, response, filterChain); - assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken()); - assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); - assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); + assertThat(request.getAttribute(token.getParameterName())).isEqualTo( + token); + assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo( + token); - verify(deniedHandler).handle(eq(request), eq(response), any(InvalidCsrfTokenException.class)); + verify(deniedHandler).handle(eq(request), eq(response), + any(InvalidCsrfTokenException.class)); verifyZeroInteractions(filterChain); } @Test - public void doFilterNotCsrfRequestExistingToken() throws ServletException, IOException { + public void doFilterNotCsrfRequestExistingToken() throws ServletException, + IOException { when(requestMatcher.matches(request)).thenReturn(false); when(tokenRepository.loadToken(request)).thenReturn(token); filter.doFilter(request, response, filterChain); - assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken()); - assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); - assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); + assertThat(request.getAttribute(token.getParameterName())).isEqualTo( + token); + assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo( + token); verify(filterChain).doFilter(request, response); verifyZeroInteractions(deniedHandler); } @Test - public void doFilterNotCsrfRequestGenerateToken() throws ServletException, IOException { + public void doFilterNotCsrfRequestGenerateToken() throws ServletException, + IOException { when(requestMatcher.matches(request)).thenReturn(false); - when(tokenRepository.generateAndSaveToken(request, response)).thenReturn(token); + when(tokenRepository.generateToken(request)) + .thenReturn(token); filter.doFilter(request, response, filterChain); - assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken()); - assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); - assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); + assertToken(request.getAttribute(token.getParameterName())).isEqualTo( + token); + assertToken(request.getAttribute(CsrfToken.class.getName())).isEqualTo( + token); verify(filterChain).doFilter(request, response); verifyZeroInteractions(deniedHandler); } @Test - public void doFilterIsCsrfRequestExistingTokenHeader() throws ServletException, IOException { + public void doFilterIsCsrfRequestExistingTokenHeader() + throws ServletException, IOException { when(requestMatcher.matches(request)).thenReturn(true); when(tokenRepository.loadToken(request)).thenReturn(token); request.addHeader(token.getHeaderName(), token.getToken()); filter.doFilter(request, response, filterChain); - assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken()); - assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); - assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); + assertThat(request.getAttribute(token.getParameterName())).isEqualTo( + token); + assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo( + token); verify(filterChain).doFilter(request, response); verifyZeroInteractions(deniedHandler); } @Test - public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam() throws ServletException, IOException { + public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam() + throws ServletException, IOException { when(requestMatcher.matches(request)).thenReturn(true); when(tokenRepository.loadToken(request)).thenReturn(token); - request.setParameter(token.getParameterName(), token.getToken()+ " INVALID"); + request.setParameter(token.getParameterName(), token.getToken() + + " INVALID"); request.addHeader(token.getHeaderName(), token.getToken()); filter.doFilter(request, response, filterChain); - assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken()); - assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); - assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); + assertThat(request.getAttribute(token.getParameterName())).isEqualTo( + token); + assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo( + token); verify(filterChain).doFilter(request, response); verifyZeroInteractions(deniedHandler); } @Test - public void doFilterIsCsrfRequestExistingToken() throws ServletException, IOException { + public void doFilterIsCsrfRequestExistingToken() throws ServletException, + IOException { when(requestMatcher.matches(request)).thenReturn(true); when(tokenRepository.loadToken(request)).thenReturn(token); request.setParameter(token.getParameterName(), token.getToken()); filter.doFilter(request, response, filterChain); - assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken()); - assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); - assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); + assertThat(request.getAttribute(token.getParameterName())).isEqualTo( + token); + assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo( + token); verify(filterChain).doFilter(request, response); verifyZeroInteractions(deniedHandler); } @Test - public void doFilterIsCsrfRequestGenerateToken() throws ServletException, IOException { + public void doFilterIsCsrfRequestGenerateToken() throws ServletException, + IOException { when(requestMatcher.matches(request)).thenReturn(true); - when(tokenRepository.generateAndSaveToken(request, response)).thenReturn(token); + when(tokenRepository.generateToken(request)) + .thenReturn(token); request.setParameter(token.getParameterName(), token.getToken()); filter.doFilter(request, response, filterChain); - assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken()); - assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); - assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); + assertToken(request.getAttribute(token.getParameterName())).isEqualTo( + token); + assertToken(request.getAttribute(CsrfToken.class.getName())).isEqualTo( + token); verify(filterChain).doFilter(request, response); verifyZeroInteractions(deniedHandler); } @Test - public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods() throws ServletException, IOException { + public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods() + throws ServletException, IOException { filter = new CsrfFilter(tokenRepository); filter.setAccessDeniedHandler(deniedHandler); - for(String method : Arrays.asList("GET","TRACE", "OPTIONS", "HEAD")) { + for (String method : Arrays.asList("GET", "TRACE", "OPTIONS", "HEAD")) { resetRequestResponse(); when(tokenRepository.loadToken(request)).thenReturn(token); request.setMethod(method); @@ -258,24 +312,28 @@ public class CsrfFilterTests { } @Test - public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods() throws ServletException, IOException { + public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods() + throws ServletException, IOException { filter = new CsrfFilter(tokenRepository); filter.setAccessDeniedHandler(deniedHandler); - for(String method : Arrays.asList("POST","PUT", "PATCH", "DELETE", "INVALID")) { + for (String method : Arrays.asList("POST", "PUT", "PATCH", "DELETE", + "INVALID")) { resetRequestResponse(); when(tokenRepository.loadToken(request)).thenReturn(token); request.setMethod(method); filter.doFilter(request, response, filterChain); - verify(deniedHandler).handle(eq(request), eq(response), any(InvalidCsrfTokenException.class)); + verify(deniedHandler).handle(eq(request), eq(response), + any(InvalidCsrfTokenException.class)); verifyZeroInteractions(filterChain); } } @Test - public void doFilterDefaultAccessDenied() throws ServletException, IOException { + public void doFilterDefaultAccessDenied() throws ServletException, + IOException { filter = new CsrfFilter(tokenRepository); filter.setRequireCsrfProtectionMatcher(requestMatcher); when(requestMatcher.matches(request)).thenReturn(true); @@ -283,11 +341,13 @@ public class CsrfFilterTests { filter.doFilter(request, response, filterChain); - assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken()); - assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token); - assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token); + assertThat(request.getAttribute(token.getParameterName())).isEqualTo( + token); + assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo( + token); - assertThat(response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); + assertThat(response.getStatus()).isEqualTo( + HttpServletResponse.SC_FORBIDDEN); verifyZeroInteractions(filterChain); } @@ -300,4 +360,29 @@ public class CsrfFilterTests { public void setAccessDeniedHandlerNull() { filter.setAccessDeniedHandler(null); } + + private static final CsrfTokenAssert assertToken(Object token) { + return new CsrfTokenAssert((CsrfToken)token); + } + + private static class CsrfTokenAssert extends + GenericAssert { + + /** + * Creates a new {@link ObjectAssert}. + * + * @param actual + * the target to verify. + */ + protected CsrfTokenAssert(CsrfToken actual) { + super(CsrfTokenAssert.class, actual); + } + + public CsrfTokenAssert isEqualTo(CsrfToken expected) { + assertThat(actual.getHeaderName()).isEqualTo(expected.getHeaderName()); + assertThat(actual.getParameterName()).isEqualTo(expected.getParameterName()); + assertThat(actual.getToken()).isEqualTo(expected.getToken()); + return this; + } + } } diff --git a/web/src/test/java/org/springframework/security/web/csrf/CsrfTokenTests.java b/web/src/test/java/org/springframework/security/web/csrf/DefaultCsrfTokenTests.java similarity index 79% rename from web/src/test/java/org/springframework/security/web/csrf/CsrfTokenTests.java rename to web/src/test/java/org/springframework/security/web/csrf/DefaultCsrfTokenTests.java index f3997b5f28..531f1f592a 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/CsrfTokenTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/DefaultCsrfTokenTests.java @@ -21,38 +21,38 @@ import org.junit.Test; * @author Rob Winch * */ -public class CsrfTokenTests { +public class DefaultCsrfTokenTests { private final String headerName = "headerName"; private final String parameterName = "parameterName"; private final String tokenValue = "tokenValue"; @Test(expected = IllegalArgumentException.class) public void constructorNullHeaderName() { - new CsrfToken(null,parameterName, tokenValue); + new DefaultCsrfToken(null,parameterName, tokenValue); } @Test(expected = IllegalArgumentException.class) public void constructorEmptyHeaderName() { - new CsrfToken("",parameterName, tokenValue); + new DefaultCsrfToken("",parameterName, tokenValue); } @Test(expected = IllegalArgumentException.class) public void constructorNullParameterName() { - new CsrfToken(headerName,null, tokenValue); + new DefaultCsrfToken(headerName,null, tokenValue); } @Test(expected = IllegalArgumentException.class) public void constructorEmptyParameterName() { - new CsrfToken(headerName,"", tokenValue); + new DefaultCsrfToken(headerName,"", tokenValue); } @Test(expected = IllegalArgumentException.class) public void constructorNullTokenValue() { - new CsrfToken(headerName,parameterName, null); + new DefaultCsrfToken(headerName,parameterName, null); } @Test(expected = IllegalArgumentException.class) public void constructorEmptyTokenValue() { - new CsrfToken(headerName,parameterName, ""); + new DefaultCsrfToken(headerName,parameterName, ""); } } diff --git a/web/src/test/java/org/springframework/security/web/csrf/HttpSessionCsrfTokenRepositoryTests.java b/web/src/test/java/org/springframework/security/web/csrf/HttpSessionCsrfTokenRepositoryTests.java index d450452fe3..2820f90d83 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/HttpSessionCsrfTokenRepositoryTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/HttpSessionCsrfTokenRepositoryTests.java @@ -42,23 +42,23 @@ public class HttpSessionCsrfTokenRepositoryTests { } @Test - public void generateAndSaveToken() { - token = repo.generateAndSaveToken(request, response); + public void generateToken() { + token = repo.generateToken(request); assertThat(token.getParameterName()).isEqualTo("_csrf"); assertThat(token.getToken()).isNotEmpty(); CsrfToken loadedToken = repo.loadToken(request); - assertThat(loadedToken).isEqualTo(token); + assertThat(loadedToken).isNull(); } @Test - public void generateAndSaveTokenCustomParameter() { + public void generateCustomParameter() { String paramName = "_csrf"; repo.setParameterName(paramName); - token = repo.generateAndSaveToken(request, response); + token = repo.generateToken(request); assertThat(token.getParameterName()).isEqualTo(paramName); assertThat(token.getToken()).isNotEmpty(); @@ -71,7 +71,7 @@ public class HttpSessionCsrfTokenRepositoryTests { @Test public void saveToken() { - CsrfToken tokenToSave = new CsrfToken("123", "abc", "def"); + CsrfToken tokenToSave = new DefaultCsrfToken("123", "abc", "def"); repo.saveToken(tokenToSave, request, response); String attrName = request.getSession().getAttributeNames() @@ -84,7 +84,7 @@ public class HttpSessionCsrfTokenRepositoryTests { @Test public void saveTokenCustomSessionAttribute() { - CsrfToken tokenToSave = new CsrfToken("123", "abc", "def"); + CsrfToken tokenToSave = new DefaultCsrfToken("123", "abc", "def"); String sessionAttributeName = "custom"; repo.setSessionAttributeName(sessionAttributeName); repo.saveToken(tokenToSave, request, response); diff --git a/web/src/test/java/org/springframework/security/web/servlet/support/csrf/CsrfRequestDataValueProcessorTests.java b/web/src/test/java/org/springframework/security/web/servlet/support/csrf/CsrfRequestDataValueProcessorTests.java index 98ced2d165..525c129919 100644 --- a/web/src/test/java/org/springframework/security/web/servlet/support/csrf/CsrfRequestDataValueProcessorTests.java +++ b/web/src/test/java/org/springframework/security/web/servlet/support/csrf/CsrfRequestDataValueProcessorTests.java @@ -25,6 +25,7 @@ import org.junit.Test; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.web.csrf.CsrfToken; +import org.springframework.security.web.csrf.DefaultCsrfToken; /** * @author Rob Winch @@ -51,7 +52,7 @@ public class CsrfRequestDataValueProcessorTests { @Test public void getExtraHiddenFieldsHasCsrfToken() { - CsrfToken token = new CsrfToken("1", "a", "b"); + CsrfToken token = new DefaultCsrfToken("1", "a", "b"); request.setAttribute(CsrfToken.class.getName(), token); Map expected = new HashMap(); expected.put(token.getParameterName(),token.getToken());