Defer CsrfFilter Session Access

Closes gh-11456
This commit is contained in:
Rob Winch 2022-08-16 11:31:27 -05:00
commit c1a6cea60a
11 changed files with 185 additions and 1 deletions

View File

@ -89,6 +89,8 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>>
private SessionAuthenticationStrategy sessionAuthenticationStrategy; private SessionAuthenticationStrategy sessionAuthenticationStrategy;
private String csrfRequestAttributeName;
private final ApplicationContext context; private final ApplicationContext context;
/** /**
@ -124,6 +126,16 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>>
return this; return this;
} }
/**
* Sets the {@link CsrfFilter#setCsrfRequestAttributeName(String)}
* @param csrfRequestAttributeName the attribute name to set the CsrfToken on.
* @return the {@link CsrfConfigurer} for further customizations.
*/
public CsrfConfigurer<H> csrfRequestAttributeName(String csrfRequestAttributeName) {
this.csrfRequestAttributeName = csrfRequestAttributeName;
return this;
}
/** /**
* <p> * <p>
* Allows specifying {@link HttpServletRequest} that should not use CSRF Protection * Allows specifying {@link HttpServletRequest} that should not use CSRF Protection
@ -202,6 +214,9 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>>
@Override @Override
public void configure(H http) { public void configure(H http) {
CsrfFilter filter = new CsrfFilter(this.csrfTokenRepository); CsrfFilter filter = new CsrfFilter(this.csrfTokenRepository);
if (this.csrfRequestAttributeName != null) {
filter.setCsrfRequestAttributeName(this.csrfRequestAttributeName);
}
RequestMatcher requireCsrfProtectionMatcher = getRequireCsrfProtectionMatcher(); RequestMatcher requireCsrfProtectionMatcher = getRequireCsrfProtectionMatcher();
if (requireCsrfProtectionMatcher != null) { if (requireCsrfProtectionMatcher != null) {
filter.setRequireCsrfProtectionMatcher(requireCsrfProtectionMatcher); filter.setRequireCsrfProtectionMatcher(requireCsrfProtectionMatcher);

View File

@ -67,10 +67,14 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser {
private static final String DISPATCHER_SERVLET_CLASS_NAME = "org.springframework.web.servlet.DispatcherServlet"; private static final String DISPATCHER_SERVLET_CLASS_NAME = "org.springframework.web.servlet.DispatcherServlet";
private static final String ATT_REQUEST_ATTRIBUTE_NAME = "request-attribute-name";
private static final String ATT_MATCHER = "request-matcher-ref"; private static final String ATT_MATCHER = "request-matcher-ref";
private static final String ATT_REPOSITORY = "token-repository-ref"; private static final String ATT_REPOSITORY = "token-repository-ref";
private String requestAttributeName;
private String csrfRepositoryRef; private String csrfRepositoryRef;
private BeanDefinition csrfFilter; private BeanDefinition csrfFilter;
@ -94,6 +98,7 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser {
} }
if (element != null) { if (element != null) {
this.csrfRepositoryRef = element.getAttribute(ATT_REPOSITORY); this.csrfRepositoryRef = element.getAttribute(ATT_REPOSITORY);
this.requestAttributeName = element.getAttribute(ATT_REQUEST_ATTRIBUTE_NAME);
this.requestMatcherRef = element.getAttribute(ATT_MATCHER); this.requestMatcherRef = element.getAttribute(ATT_MATCHER);
} }
if (!StringUtils.hasText(this.csrfRepositoryRef)) { if (!StringUtils.hasText(this.csrfRepositoryRef)) {
@ -110,6 +115,9 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser {
if (StringUtils.hasText(this.requestMatcherRef)) { if (StringUtils.hasText(this.requestMatcherRef)) {
builder.addPropertyReference("requireCsrfProtectionMatcher", this.requestMatcherRef); builder.addPropertyReference("requireCsrfProtectionMatcher", this.requestMatcherRef);
} }
if (StringUtils.hasText(this.requestAttributeName)) {
builder.addPropertyValue("csrfRequestAttributeName", this.requestAttributeName);
}
this.csrfFilter = builder.getBeanDefinition(); this.csrfFilter = builder.getBeanDefinition();
return this.csrfFilter; return this.csrfFilter;
} }

View File

@ -1136,6 +1136,9 @@ csrf =
csrf-options.attlist &= csrf-options.attlist &=
## Specifies if csrf protection should be disabled. Default false (i.e. CSRF protection is enabled). ## Specifies if csrf protection should be disabled. Default false (i.e. CSRF protection is enabled).
attribute disabled {xsd:boolean}? attribute disabled {xsd:boolean}?
csrf-options.attlist &=
## The request attribute name the CsrfToken is set on. Default is to set to CsrfToken.parameterName
attribute request-attribute-name { xsd:token }?
csrf-options.attlist &= csrf-options.attlist &=
## The RequestMatcher instance to be used to determine if CSRF should be applied. Default is any HTTP method except "GET", "TRACE", "HEAD", "OPTIONS" ## The RequestMatcher instance to be used to determine if CSRF should be applied. Default is any HTTP method except "GET", "TRACE", "HEAD", "OPTIONS"
attribute request-matcher-ref { xsd:token }? attribute request-matcher-ref { xsd:token }?

View File

@ -3217,6 +3217,13 @@
</xs:documentation> </xs:documentation>
</xs:annotation> </xs:annotation>
</xs:attribute> </xs:attribute>
<xs:attribute name="request-attribute-name" type="xs:token">
<xs:annotation>
<xs:documentation>The request attribute name the CsrfToken is set on. Default is to set to
CsrfToken.parameterName
</xs:documentation>
</xs:annotation>
</xs:attribute>
<xs:attribute name="request-matcher-ref" type="xs:token"> <xs:attribute name="request-matcher-ref" type="xs:token">
<xs:annotation> <xs:annotation>
<xs:documentation>The RequestMatcher instance to be used to determine if CSRF should be applied. Default is <xs:documentation>The RequestMatcher instance to be used to determine if CSRF should be applied. Default is

View File

@ -291,6 +291,15 @@ public class CsrfConfigTests {
// @formatter:on // @formatter:on
} }
@Test
public void getWhenUsingCsrfAndCustomRequestAttributeThenSetUsingCsrfAttrName() throws Exception {
this.spring.configLocations(this.xml("WithRequestAttrName")).autowire();
// @formatter:off
MvcResult result = this.mvc.perform(get("/ok")).andReturn();
assertThat(result.getRequest().getAttribute("csrf-attribute-name")).isInstanceOf(CsrfToken.class);
// @formatter:on
}
@Test @Test
public void postWhenHasCsrfTokenButSessionExpiresThenRequestIsCancelledAfterSuccessfulAuthentication() public void postWhenHasCsrfTokenButSessionExpiresThenRequestIsCancelledAfterSuccessfulAuthentication()
throws Exception { throws Exception {

View File

@ -0,0 +1,29 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ Copyright 2002-2018 the original author or authors.
~
~ Licensed under the Apache License, Version 2.0 (the "License");
~ you may not use this file except in compliance with the License.
~ You may obtain a copy of the License at
~
~ https://www.apache.org/licenses/LICENSE-2.0
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS,
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
~ See the License for the specific language governing permissions and
~ limitations under the License.
-->
<b:beans xmlns:b="http://www.springframework.org/schema/beans"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xmlns="http://www.springframework.org/schema/security"
xsi:schemaLocation="http://www.springframework.org/schema/security https://www.springframework.org/schema/security/spring-security.xsd
http://www.springframework.org/schema/beans https://www.springframework.org/schema/beans/spring-beans.xsd">
<http auto-config="true">
<csrf request-attribute-name="csrf-attribute-name"/>
</http>
<b:import resource="CsrfConfigTests-shared-userservice.xml"/>
</b:beans>

View File

@ -775,6 +775,10 @@ It is highly recommended to leave CSRF protection enabled.
The CsrfTokenRepository to use. The CsrfTokenRepository to use.
The default is `HttpSessionCsrfTokenRepository`. The default is `HttpSessionCsrfTokenRepository`.
[[nsa-csrf-request-attribute-name]]
* **request-attribute-name**
Optional attribute that specifies the request attribute name to set the `CsrfToken` on.
The default is `CsrfToken.parameterName`.
[[nsa-csrf-request-matcher-ref]] [[nsa-csrf-request-matcher-ref]]
* **request-matcher-ref** * **request-matcher-ref**

View File

@ -87,6 +87,8 @@ public final class CsrfFilter extends OncePerRequestFilter {
private AccessDeniedHandler accessDeniedHandler = new AccessDeniedHandlerImpl(); private AccessDeniedHandler accessDeniedHandler = new AccessDeniedHandlerImpl();
private String csrfRequestAttributeName;
public CsrfFilter(CsrfTokenRepository csrfTokenRepository) { public CsrfFilter(CsrfTokenRepository csrfTokenRepository) {
Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null"); Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null");
this.tokenRepository = csrfTokenRepository; this.tokenRepository = csrfTokenRepository;
@ -108,7 +110,9 @@ public final class CsrfFilter extends OncePerRequestFilter {
this.tokenRepository.saveToken(csrfToken, request, response); this.tokenRepository.saveToken(csrfToken, request, response);
} }
request.setAttribute(CsrfToken.class.getName(), csrfToken); request.setAttribute(CsrfToken.class.getName(), csrfToken);
request.setAttribute(csrfToken.getParameterName(), csrfToken); String csrfAttrName = (this.csrfRequestAttributeName != null) ? this.csrfRequestAttributeName
: csrfToken.getParameterName();
request.setAttribute(csrfAttrName, csrfToken);
if (!this.requireCsrfProtectionMatcher.matches(request)) { if (!this.requireCsrfProtectionMatcher.matches(request)) {
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
this.logger.trace("Did not protect against CSRF since request did not match " this.logger.trace("Did not protect against CSRF since request did not match "
@ -167,6 +171,18 @@ public final class CsrfFilter extends OncePerRequestFilter {
this.accessDeniedHandler = accessDeniedHandler; this.accessDeniedHandler = accessDeniedHandler;
} }
/**
* The {@link CsrfToken} is available as a request attribute named
* {@code CsrfToken.class.getName()}. By default, an additional request attribute that
* is the same as {@link CsrfToken#getParameterName()} is set. This attribute allows
* overriding the additional attribute.
* @param csrfRequestAttributeName the name of an additional request attribute with
* the value of the CsrfToken. Default is {@link CsrfToken#getParameterName()}
*/
public void setCsrfRequestAttributeName(String csrfRequestAttributeName) {
this.csrfRequestAttributeName = csrfRequestAttributeName;
}
/** /**
* Constant time comparison to prevent against timing attacks. * Constant time comparison to prevent against timing attacks.
* @param expected * @param expected

View File

@ -38,6 +38,8 @@ public final class LazyCsrfTokenRepository implements CsrfTokenRepository {
private final CsrfTokenRepository delegate; private final CsrfTokenRepository delegate;
private boolean deferLoadToken;
/** /**
* Creates a new instance * Creates a new instance
* @param delegate the {@link CsrfTokenRepository} to use. Cannot be null * @param delegate the {@link CsrfTokenRepository} to use. Cannot be null
@ -48,6 +50,15 @@ public final class LazyCsrfTokenRepository implements CsrfTokenRepository {
this.delegate = delegate; this.delegate = delegate;
} }
/**
* Determines if {@link #loadToken(HttpServletRequest)} should be lazily loaded.
* @param deferLoadToken true if should lazily load
* {@link #loadToken(HttpServletRequest)}. Default false.
*/
public void setDeferLoadToken(boolean deferLoadToken) {
this.deferLoadToken = deferLoadToken;
}
/** /**
* Generates a new token * Generates a new token
* @param request the {@link HttpServletRequest} to use. The * @param request the {@link HttpServletRequest} to use. The
@ -77,6 +88,9 @@ public final class LazyCsrfTokenRepository implements CsrfTokenRepository {
*/ */
@Override @Override
public CsrfToken loadToken(HttpServletRequest request) { public CsrfToken loadToken(HttpServletRequest request) {
if (this.deferLoadToken) {
return new LazyLoadCsrfToken(request, this.delegate);
}
return this.delegate.loadToken(request); return this.delegate.loadToken(request);
} }
@ -92,6 +106,55 @@ public final class LazyCsrfTokenRepository implements CsrfTokenRepository {
return response; return response;
} }
private final class LazyLoadCsrfToken implements CsrfToken {
private final HttpServletRequest request;
private final CsrfTokenRepository tokenRepository;
private CsrfToken token;
private LazyLoadCsrfToken(HttpServletRequest request, CsrfTokenRepository tokenRepository) {
this.request = request;
this.tokenRepository = tokenRepository;
}
private CsrfToken getDelegate() {
if (this.token != null) {
return this.token;
}
// load from the delegate repository
this.token = LazyCsrfTokenRepository.this.delegate.loadToken(this.request);
if (this.token == null) {
// return a generated token that is lazily saved since
// LazyCsrfTokenRepository#loadToken always returns a value
this.token = generateToken(this.request);
}
return this.token;
}
@Override
public String getHeaderName() {
return getDelegate().getHeaderName();
}
@Override
public String getParameterName() {
return getDelegate().getParameterName();
}
@Override
public String getToken() {
return getDelegate().getToken();
}
@Override
public String toString() {
return "LazyLoadCsrfToken{" + "token=" + this.token + '}';
}
}
private static final class SaveOnAccessCsrfToken implements CsrfToken { private static final class SaveOnAccessCsrfToken implements CsrfToken {
private transient CsrfTokenRepository tokenRepository; private transient CsrfTokenRepository tokenRepository;

View File

@ -48,6 +48,7 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.verifyZeroInteractions;
/** /**
@ -344,6 +345,23 @@ public class CsrfFilterTests {
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAccessDeniedHandler(null)); assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAccessDeniedHandler(null));
} }
// This ensures that the HttpSession on get requests unless the CsrfToken is used
@Test
public void doFilterWhenCsrfRequestAttributeNameThenNoCsrfTokenMethodInvokedOnGet()
throws ServletException, IOException {
CsrfFilter filter = createCsrfFilter(this.tokenRepository);
String csrfAttrName = "_csrf";
filter.setCsrfRequestAttributeName(csrfAttrName);
CsrfToken expectedCsrfToken = mock(CsrfToken.class);
given(this.tokenRepository.loadToken(this.request)).willReturn(expectedCsrfToken);
filter.doFilter(this.request, this.response, this.filterChain);
verifyNoInteractions(expectedCsrfToken);
CsrfToken tokenFromRequest = (CsrfToken) this.request.getAttribute(csrfAttrName);
assertThat(tokenFromRequest).isEqualTo(expectedCsrfToken);
}
private static CsrfTokenAssert assertToken(Object token) { private static CsrfTokenAssert assertToken(Object token) {
return new CsrfTokenAssert((CsrfToken) token); return new CsrfTokenAssert((CsrfToken) token);
} }

View File

@ -31,6 +31,7 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.verifyZeroInteractions;
/** /**
@ -98,4 +99,15 @@ public class LazyCsrfTokenRepositoryTests {
verify(this.delegate).loadToken(this.request); verify(this.delegate).loadToken(this.request);
} }
@Test
public void loadTokenWhenDeferLoadToken() {
given(this.delegate.loadToken(this.request)).willReturn(this.token);
this.repository.setDeferLoadToken(true);
CsrfToken loadToken = this.repository.loadToken(this.request);
verifyNoInteractions(this.delegate);
assertThat(loadToken.getToken()).isEqualTo(this.token.getToken());
assertThat(loadToken.getHeaderName()).isEqualTo(this.token.getHeaderName());
assertThat(loadToken.getParameterName()).isEqualTo(this.token.getParameterName());
}
} }