diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java index b9c5cc7c63..3831190e90 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java @@ -89,6 +89,8 @@ public final class CsrfConfigurer> private SessionAuthenticationStrategy sessionAuthenticationStrategy; + private String csrfRequestAttributeName; + private final ApplicationContext context; /** @@ -124,6 +126,16 @@ public final class CsrfConfigurer> return this; } + /** + * Sets the {@link CsrfFilter#setCsrfRequestAttributeName(String)} + * @param csrfRequestAttributeName the attribute name to set the CsrfToken on. + * @return the {@link CsrfConfigurer} for further customizations. + */ + public CsrfConfigurer csrfRequestAttributeName(String csrfRequestAttributeName) { + this.csrfRequestAttributeName = csrfRequestAttributeName; + return this; + } + /** *

* Allows specifying {@link HttpServletRequest} that should not use CSRF Protection @@ -202,6 +214,9 @@ public final class CsrfConfigurer> @Override public void configure(H http) { CsrfFilter filter = new CsrfFilter(this.csrfTokenRepository); + if (this.csrfRequestAttributeName != null) { + filter.setCsrfRequestAttributeName(this.csrfRequestAttributeName); + } RequestMatcher requireCsrfProtectionMatcher = getRequireCsrfProtectionMatcher(); if (requireCsrfProtectionMatcher != null) { filter.setRequireCsrfProtectionMatcher(requireCsrfProtectionMatcher); diff --git a/config/src/main/java/org/springframework/security/config/http/CsrfBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/http/CsrfBeanDefinitionParser.java index 58dcd468a8..495a2ddd2a 100644 --- a/config/src/main/java/org/springframework/security/config/http/CsrfBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/http/CsrfBeanDefinitionParser.java @@ -67,10 +67,14 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser { private static final String DISPATCHER_SERVLET_CLASS_NAME = "org.springframework.web.servlet.DispatcherServlet"; + private static final String ATT_REQUEST_ATTRIBUTE_NAME = "request-attribute-name"; + private static final String ATT_MATCHER = "request-matcher-ref"; private static final String ATT_REPOSITORY = "token-repository-ref"; + private String requestAttributeName; + private String csrfRepositoryRef; private BeanDefinition csrfFilter; @@ -94,6 +98,7 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser { } if (element != null) { this.csrfRepositoryRef = element.getAttribute(ATT_REPOSITORY); + this.requestAttributeName = element.getAttribute(ATT_REQUEST_ATTRIBUTE_NAME); this.requestMatcherRef = element.getAttribute(ATT_MATCHER); } if (!StringUtils.hasText(this.csrfRepositoryRef)) { @@ -110,6 +115,9 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser { if (StringUtils.hasText(this.requestMatcherRef)) { builder.addPropertyReference("requireCsrfProtectionMatcher", this.requestMatcherRef); } + if (StringUtils.hasText(this.requestAttributeName)) { + builder.addPropertyValue("csrfRequestAttributeName", this.requestAttributeName); + } this.csrfFilter = builder.getBeanDefinition(); return this.csrfFilter; } diff --git a/config/src/main/resources/org/springframework/security/config/spring-security-5.8.rnc b/config/src/main/resources/org/springframework/security/config/spring-security-5.8.rnc index 36afa1b42b..6104ee7cc5 100644 --- a/config/src/main/resources/org/springframework/security/config/spring-security-5.8.rnc +++ b/config/src/main/resources/org/springframework/security/config/spring-security-5.8.rnc @@ -1136,6 +1136,9 @@ csrf = csrf-options.attlist &= ## Specifies if csrf protection should be disabled. Default false (i.e. CSRF protection is enabled). attribute disabled {xsd:boolean}? +csrf-options.attlist &= + ## The request attribute name the CsrfToken is set on. Default is to set to CsrfToken.parameterName + attribute request-attribute-name { xsd:token }? csrf-options.attlist &= ## The RequestMatcher instance to be used to determine if CSRF should be applied. Default is any HTTP method except "GET", "TRACE", "HEAD", "OPTIONS" attribute request-matcher-ref { xsd:token }? diff --git a/config/src/main/resources/org/springframework/security/config/spring-security-5.8.xsd b/config/src/main/resources/org/springframework/security/config/spring-security-5.8.xsd index 2b98c8e68d..4255a1ae11 100644 --- a/config/src/main/resources/org/springframework/security/config/spring-security-5.8.xsd +++ b/config/src/main/resources/org/springframework/security/config/spring-security-5.8.xsd @@ -3217,6 +3217,13 @@ + + + The request attribute name the CsrfToken is set on. Default is to set to + CsrfToken.parameterName + + + The RequestMatcher instance to be used to determine if CSRF should be applied. Default is diff --git a/config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java b/config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java index ce17aebd9a..a62bfae26d 100644 --- a/config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java @@ -291,6 +291,15 @@ public class CsrfConfigTests { // @formatter:on } + @Test + public void getWhenUsingCsrfAndCustomRequestAttributeThenSetUsingCsrfAttrName() throws Exception { + this.spring.configLocations(this.xml("WithRequestAttrName")).autowire(); + // @formatter:off + MvcResult result = this.mvc.perform(get("/ok")).andReturn(); + assertThat(result.getRequest().getAttribute("csrf-attribute-name")).isInstanceOf(CsrfToken.class); + // @formatter:on + } + @Test public void postWhenHasCsrfTokenButSessionExpiresThenRequestIsCancelledAfterSuccessfulAuthentication() throws Exception { diff --git a/config/src/test/resources/org/springframework/security/config/http/CsrfConfigTests-WithRequestAttrName.xml b/config/src/test/resources/org/springframework/security/config/http/CsrfConfigTests-WithRequestAttrName.xml new file mode 100644 index 0000000000..4f6c27248b --- /dev/null +++ b/config/src/test/resources/org/springframework/security/config/http/CsrfConfigTests-WithRequestAttrName.xml @@ -0,0 +1,29 @@ + + + + + + + + + + + diff --git a/docs/modules/ROOT/pages/servlet/appendix/namespace/http.adoc b/docs/modules/ROOT/pages/servlet/appendix/namespace/http.adoc index f707ca0453..1bd4b70907 100644 --- a/docs/modules/ROOT/pages/servlet/appendix/namespace/http.adoc +++ b/docs/modules/ROOT/pages/servlet/appendix/namespace/http.adoc @@ -775,6 +775,10 @@ It is highly recommended to leave CSRF protection enabled. The CsrfTokenRepository to use. The default is `HttpSessionCsrfTokenRepository`. +[[nsa-csrf-request-attribute-name]] +* **request-attribute-name** +Optional attribute that specifies the request attribute name to set the `CsrfToken` on. +The default is `CsrfToken.parameterName`. [[nsa-csrf-request-matcher-ref]] * **request-matcher-ref** diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java index 8490c508da..1c452ecf45 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java @@ -87,6 +87,8 @@ public final class CsrfFilter extends OncePerRequestFilter { private AccessDeniedHandler accessDeniedHandler = new AccessDeniedHandlerImpl(); + private String csrfRequestAttributeName; + public CsrfFilter(CsrfTokenRepository csrfTokenRepository) { Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null"); this.tokenRepository = csrfTokenRepository; @@ -108,7 +110,9 @@ public final class CsrfFilter extends OncePerRequestFilter { this.tokenRepository.saveToken(csrfToken, request, response); } request.setAttribute(CsrfToken.class.getName(), csrfToken); - request.setAttribute(csrfToken.getParameterName(), csrfToken); + String csrfAttrName = (this.csrfRequestAttributeName != null) ? this.csrfRequestAttributeName + : csrfToken.getParameterName(); + request.setAttribute(csrfAttrName, csrfToken); if (!this.requireCsrfProtectionMatcher.matches(request)) { if (this.logger.isTraceEnabled()) { this.logger.trace("Did not protect against CSRF since request did not match " @@ -167,6 +171,18 @@ public final class CsrfFilter extends OncePerRequestFilter { this.accessDeniedHandler = accessDeniedHandler; } + /** + * The {@link CsrfToken} is available as a request attribute named + * {@code CsrfToken.class.getName()}. By default, an additional request attribute that + * is the same as {@link CsrfToken#getParameterName()} is set. This attribute allows + * overriding the additional attribute. + * @param csrfRequestAttributeName the name of an additional request attribute with + * the value of the CsrfToken. Default is {@link CsrfToken#getParameterName()} + */ + public void setCsrfRequestAttributeName(String csrfRequestAttributeName) { + this.csrfRequestAttributeName = csrfRequestAttributeName; + } + /** * Constant time comparison to prevent against timing attacks. * @param expected diff --git a/web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java b/web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java index b0f4263f2a..a9817cc9c7 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java +++ b/web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java @@ -38,6 +38,8 @@ public final class LazyCsrfTokenRepository implements CsrfTokenRepository { private final CsrfTokenRepository delegate; + private boolean deferLoadToken; + /** * Creates a new instance * @param delegate the {@link CsrfTokenRepository} to use. Cannot be null @@ -48,6 +50,15 @@ public final class LazyCsrfTokenRepository implements CsrfTokenRepository { this.delegate = delegate; } + /** + * Determines if {@link #loadToken(HttpServletRequest)} should be lazily loaded. + * @param deferLoadToken true if should lazily load + * {@link #loadToken(HttpServletRequest)}. Default false. + */ + public void setDeferLoadToken(boolean deferLoadToken) { + this.deferLoadToken = deferLoadToken; + } + /** * Generates a new token * @param request the {@link HttpServletRequest} to use. The @@ -77,6 +88,9 @@ public final class LazyCsrfTokenRepository implements CsrfTokenRepository { */ @Override public CsrfToken loadToken(HttpServletRequest request) { + if (this.deferLoadToken) { + return new LazyLoadCsrfToken(request, this.delegate); + } return this.delegate.loadToken(request); } @@ -92,6 +106,55 @@ public final class LazyCsrfTokenRepository implements CsrfTokenRepository { return response; } + private final class LazyLoadCsrfToken implements CsrfToken { + + private final HttpServletRequest request; + + private final CsrfTokenRepository tokenRepository; + + private CsrfToken token; + + private LazyLoadCsrfToken(HttpServletRequest request, CsrfTokenRepository tokenRepository) { + this.request = request; + this.tokenRepository = tokenRepository; + } + + private CsrfToken getDelegate() { + if (this.token != null) { + return this.token; + } + // load from the delegate repository + this.token = LazyCsrfTokenRepository.this.delegate.loadToken(this.request); + if (this.token == null) { + // return a generated token that is lazily saved since + // LazyCsrfTokenRepository#loadToken always returns a value + this.token = generateToken(this.request); + } + return this.token; + } + + @Override + public String getHeaderName() { + return getDelegate().getHeaderName(); + } + + @Override + public String getParameterName() { + return getDelegate().getParameterName(); + } + + @Override + public String getToken() { + return getDelegate().getToken(); + } + + @Override + public String toString() { + return "LazyLoadCsrfToken{" + "token=" + this.token + '}'; + } + + } + private static final class SaveOnAccessCsrfToken implements CsrfToken { private transient CsrfTokenRepository tokenRepository; diff --git a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java index b4244f228d..847f1b85da 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java @@ -48,6 +48,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyZeroInteractions; /** @@ -344,6 +345,23 @@ public class CsrfFilterTests { assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAccessDeniedHandler(null)); } + // This ensures that the HttpSession on get requests unless the CsrfToken is used + @Test + public void doFilterWhenCsrfRequestAttributeNameThenNoCsrfTokenMethodInvokedOnGet() + throws ServletException, IOException { + CsrfFilter filter = createCsrfFilter(this.tokenRepository); + String csrfAttrName = "_csrf"; + filter.setCsrfRequestAttributeName(csrfAttrName); + CsrfToken expectedCsrfToken = mock(CsrfToken.class); + given(this.tokenRepository.loadToken(this.request)).willReturn(expectedCsrfToken); + + filter.doFilter(this.request, this.response, this.filterChain); + + verifyNoInteractions(expectedCsrfToken); + CsrfToken tokenFromRequest = (CsrfToken) this.request.getAttribute(csrfAttrName); + assertThat(tokenFromRequest).isEqualTo(expectedCsrfToken); + } + private static CsrfTokenAssert assertToken(Object token) { return new CsrfTokenAssert((CsrfToken) token); } diff --git a/web/src/test/java/org/springframework/security/web/csrf/LazyCsrfTokenRepositoryTests.java b/web/src/test/java/org/springframework/security/web/csrf/LazyCsrfTokenRepositoryTests.java index 5ad35258d4..35d57634af 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/LazyCsrfTokenRepositoryTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/LazyCsrfTokenRepositoryTests.java @@ -31,6 +31,7 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyZeroInteractions; /** @@ -98,4 +99,15 @@ public class LazyCsrfTokenRepositoryTests { verify(this.delegate).loadToken(this.request); } + @Test + public void loadTokenWhenDeferLoadToken() { + given(this.delegate.loadToken(this.request)).willReturn(this.token); + this.repository.setDeferLoadToken(true); + CsrfToken loadToken = this.repository.loadToken(this.request); + verifyNoInteractions(this.delegate); + assertThat(loadToken.getToken()).isEqualTo(this.token.getToken()); + assertThat(loadToken.getHeaderName()).isEqualTo(this.token.getHeaderName()); + assertThat(loadToken.getParameterName()).isEqualTo(this.token.getParameterName()); + } + }