CsrfTokenRequestHandler extends CsrfTokenRequestResolver

Closes gh-11896
This commit is contained in:
Steve Riesenberg 2022-09-23 11:18:31 -05:00
parent d140d95305
commit 46696a9226
No known key found for this signature in database
GPG Key ID: 5F311AB48A55D521
18 changed files with 155 additions and 188 deletions

View File

@ -36,8 +36,8 @@ import org.springframework.security.web.csrf.CsrfAuthenticationStrategy;
import org.springframework.security.web.csrf.CsrfFilter;
import org.springframework.security.web.csrf.CsrfLogoutHandler;
import org.springframework.security.web.csrf.CsrfTokenRepository;
import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler;
import org.springframework.security.web.csrf.CsrfTokenRequestHandler;
import org.springframework.security.web.csrf.CsrfTokenRequestResolver;
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
import org.springframework.security.web.csrf.LazyCsrfTokenRepository;
import org.springframework.security.web.csrf.MissingCsrfTokenException;
@ -93,8 +93,6 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>>
private CsrfTokenRequestHandler requestHandler;
private CsrfTokenRequestResolver requestResolver;
private final ApplicationContext context;
/**
@ -135,23 +133,13 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>>
* available as a request attribute.
* @param requestHandler the {@link CsrfTokenRequestHandler} to use
* @return the {@link CsrfConfigurer} for further customizations
* @since 5.8
*/
public CsrfConfigurer<H> csrfTokenRequestHandler(CsrfTokenRequestHandler requestHandler) {
this.requestHandler = requestHandler;
return this;
}
/**
* Specify a {@link CsrfTokenRequestResolver} to use for resolving the token value
* from the request.
* @param requestResolver the {@link CsrfTokenRequestResolver} to use
* @return the {@link CsrfConfigurer} for further customizations
*/
public CsrfConfigurer<H> csrfTokenRequestResolver(CsrfTokenRequestResolver requestResolver) {
this.requestResolver = requestResolver;
return this;
}
/**
* <p>
* Allows specifying {@link HttpServletRequest} that should not use CSRF Protection
@ -229,7 +217,13 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>>
@SuppressWarnings("unchecked")
@Override
public void configure(H http) {
CsrfFilter filter = new CsrfFilter(this.csrfTokenRepository);
CsrfFilter filter;
if (this.requestHandler != null) {
filter = new CsrfFilter(this.requestHandler);
}
else {
filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.csrfTokenRepository));
}
RequestMatcher requireCsrfProtectionMatcher = getRequireCsrfProtectionMatcher();
if (requireCsrfProtectionMatcher != null) {
filter.setRequireCsrfProtectionMatcher(requireCsrfProtectionMatcher);
@ -246,12 +240,6 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>>
if (sessionConfigurer != null) {
sessionConfigurer.addSessionAuthenticationStrategy(getSessionAuthenticationStrategy());
}
if (this.requestHandler != null) {
filter.setRequestHandler(this.requestHandler);
}
if (this.requestResolver != null) {
filter.setRequestResolver(this.requestResolver);
}
filter = postProcess(filter);
http.addFilter(filter);
}

View File

@ -41,6 +41,7 @@ import org.springframework.security.web.access.DelegatingAccessDeniedHandler;
import org.springframework.security.web.csrf.CsrfAuthenticationStrategy;
import org.springframework.security.web.csrf.CsrfFilter;
import org.springframework.security.web.csrf.CsrfLogoutHandler;
import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler;
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
import org.springframework.security.web.csrf.LazyCsrfTokenRepository;
import org.springframework.security.web.csrf.MissingCsrfTokenException;
@ -73,8 +74,6 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser {
private static final String ATT_REQUEST_HANDLER = "request-handler-ref";
private static final String ATT_REQUEST_RESOLVER = "request-resolver-ref";
private String csrfRepositoryRef;
private BeanDefinition csrfFilter;
@ -83,8 +82,6 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser {
private String requestHandlerRef;
private String requestResolverRef;
@Override
public BeanDefinition parse(Element element, ParserContext pc) {
boolean disabled = element != null && "true".equals(element.getAttribute("disabled"));
@ -104,7 +101,6 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser {
this.csrfRepositoryRef = element.getAttribute(ATT_REPOSITORY);
this.requestMatcherRef = element.getAttribute(ATT_MATCHER);
this.requestHandlerRef = element.getAttribute(ATT_REQUEST_HANDLER);
this.requestResolverRef = element.getAttribute(ATT_REQUEST_RESOLVER);
}
if (!StringUtils.hasText(this.csrfRepositoryRef)) {
RootBeanDefinition csrfTokenRepository = new RootBeanDefinition(HttpSessionCsrfTokenRepository.class);
@ -116,16 +112,18 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser {
new BeanComponentDefinition(lazyTokenRepository.getBeanDefinition(), this.csrfRepositoryRef));
}
BeanDefinitionBuilder builder = BeanDefinitionBuilder.rootBeanDefinition(CsrfFilter.class);
builder.addConstructorArgReference(this.csrfRepositoryRef);
if (!StringUtils.hasText(this.requestHandlerRef)) {
BeanDefinition csrfTokenRequestHandler = BeanDefinitionBuilder
.rootBeanDefinition(CsrfTokenRepositoryRequestHandler.class)
.addConstructorArgReference(this.csrfRepositoryRef).getBeanDefinition();
builder.addConstructorArgValue(csrfTokenRequestHandler);
}
else {
builder.addConstructorArgReference(this.requestHandlerRef);
}
if (StringUtils.hasText(this.requestMatcherRef)) {
builder.addPropertyReference("requireCsrfProtectionMatcher", this.requestMatcherRef);
}
if (StringUtils.hasText(this.requestHandlerRef)) {
builder.addPropertyReference("requestHandler", this.requestHandlerRef);
}
if (StringUtils.hasText(this.requestResolverRef)) {
builder.addPropertyReference("requestResolver", this.requestResolverRef);
}
this.csrfFilter = builder.getBeanDefinition();
return this.csrfFilter;
}

View File

@ -1154,9 +1154,6 @@ csrf-options.attlist &=
csrf-options.attlist &=
## The CsrfTokenRequestHandler to use. The default is CsrfTokenRequestProcessor.
attribute request-handler-ref { xsd:token }?
csrf-options.attlist &=
## The CsrfTokenRequestResolver to use. The default is CsrfTokenRequestProcessor.
attribute request-resolver-ref { xsd:token }?
headers =
## Element for configuration of the HeaderWritersFilter. Enables easy setting for the X-Frame-Options, X-XSS-Protection and X-Content-Type-Options headers.

View File

@ -3258,13 +3258,7 @@
</xs:attribute>
<xs:attribute name="request-handler-ref" type="xs:token">
<xs:annotation>
<xs:documentation>The CsrfTokenRequestHandler to use. The default is CsrfTokenRequestProcessor.
</xs:documentation>
</xs:annotation>
</xs:attribute>
<xs:attribute name="request-resolver-ref" type="xs:token">
<xs:annotation>
<xs:documentation>The CsrfTokenRequestResolver to use. The default is CsrfTokenRequestProcessor.
<xs:documentation>The CsrfTokenRequestHandler to use. The default is CsrfTokenRepositoryRequestHandler.
</xs:documentation>
</xs:annotation>
</xs:attribute>

View File

@ -33,7 +33,7 @@ import org.springframework.security.config.test.SpringTestContext;
import org.springframework.security.config.test.SpringTestContextExtension;
import org.springframework.security.web.DefaultSecurityFilterChain;
import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.csrf.CsrfTokenRequestProcessor;
import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler;
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
import org.springframework.security.web.csrf.LazyCsrfTokenRepository;
import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
@ -85,7 +85,7 @@ public class DeferHttpSessionJavaConfigTests {
csrfRepository.setDeferLoadToken(true);
HttpSessionRequestCache requestCache = new HttpSessionRequestCache();
requestCache.setMatchingRequestParameterName("continue");
CsrfTokenRequestProcessor requestHandler = new CsrfTokenRequestProcessor();
CsrfTokenRepositoryRequestHandler requestHandler = new CsrfTokenRepositoryRequestHandler();
requestHandler.setCsrfRequestAttributeName("_csrf");
// @formatter:off
http

View File

@ -44,7 +44,7 @@ import org.springframework.security.web.access.AccessDeniedHandler;
import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.CsrfTokenRepository;
import org.springframework.security.web.csrf.CsrfTokenRequestProcessor;
import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler;
import org.springframework.security.web.csrf.DefaultCsrfToken;
import org.springframework.security.web.firewall.StrictHttpFirewall;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
@ -422,8 +422,7 @@ public class CsrfConfigurerTests {
CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class);
CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token");
given(csrfTokenRepository.generateToken(any(HttpServletRequest.class))).willReturn(csrfToken);
CsrfTokenRequestProcessorConfig.PROCESSOR = new CsrfTokenRequestProcessor();
CsrfTokenRequestProcessorConfig.PROCESSOR.setTokenRepository(csrfTokenRepository);
CsrfTokenRequestProcessorConfig.HANDLER = new CsrfTokenRepositoryRequestHandler(csrfTokenRepository);
this.spring.register(CsrfTokenRequestProcessorConfig.class, BasicController.class).autowire();
this.mvc.perform(get("/login")).andExpect(status().isOk())
.andExpect(content().string(containsString(csrfToken.getToken())));
@ -440,8 +439,7 @@ public class CsrfConfigurerTests {
CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class);
given(csrfTokenRepository.loadToken(any(HttpServletRequest.class))).willReturn(null, csrfToken);
given(csrfTokenRepository.generateToken(any(HttpServletRequest.class))).willReturn(csrfToken);
CsrfTokenRequestProcessorConfig.PROCESSOR = new CsrfTokenRequestProcessor();
CsrfTokenRequestProcessorConfig.PROCESSOR.setTokenRepository(csrfTokenRepository);
CsrfTokenRequestProcessorConfig.HANDLER = new CsrfTokenRepositoryRequestHandler(csrfTokenRepository);
this.spring.register(CsrfTokenRequestProcessorConfig.class, BasicController.class).autowire();
// @formatter:off
@ -803,7 +801,7 @@ public class CsrfConfigurerTests {
@EnableWebSecurity
static class CsrfTokenRequestProcessorConfig {
static CsrfTokenRequestProcessor PROCESSOR;
static CsrfTokenRepositoryRequestHandler HANDLER;
@Bean
SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
@ -813,10 +811,7 @@ public class CsrfConfigurerTests {
.anyRequest().authenticated()
)
.formLogin(Customizer.withDefaults())
.csrf((csrf) -> csrf
.csrfTokenRequestHandler(PROCESSOR)
.csrfTokenRequestResolver(PROCESSOR)
);
.csrf((csrf) -> csrf.csrfTokenRequestHandler(HANDLER));
// @formatter:on
return http.build();

View File

@ -26,7 +26,7 @@
<csrf request-handler-ref="requestHandler"/>
</http>
<b:bean id="requestHandler" class="org.springframework.security.web.csrf.CsrfTokenRequestProcessor"
<b:bean id="requestHandler" class="org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler"
p:csrfRequestAttributeName="csrf-attribute-name"/>
<b:import resource="CsrfConfigTests-shared-userservice.xml"/>
</b:beans>

View File

@ -42,7 +42,7 @@
<b:bean id="csrfRepository" class="org.springframework.security.web.csrf.LazyCsrfTokenRepository"
c:delegate-ref="httpSessionCsrfRepository"
p:deferLoadToken="true"/>
<b:bean id="requestHandler" class="org.springframework.security.web.csrf.CsrfTokenRequestProcessor"
<b:bean id="requestHandler" class="org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler"
p:csrfRequestAttributeName="_csrf"/>
<b:import resource="CsrfConfigTests-shared-userservice.xml"/>
</b:beans>

View File

@ -777,11 +777,7 @@ The default is `HttpSessionCsrfTokenRepository`.
[[nsa-csrf-request-handler-ref]]
* **request-handler-ref**
The optional `CsrfTokenRequestHandler` to use. The default is `CsrfTokenRequestProcessor`.
[[nsa-csrf-request-resolver-ref]]
* **request-resolver-ref**
The optional `CsrfTokenRequestResolver` to use. The default is `CsrfTokenRequestProcessor`.
The optional `CsrfTokenRequestHandler` to use. The default is `CsrfTokenRepositoryRequestHandler`.
[[nsa-csrf-request-matcher-ref]]
* **request-matcher-ref**

View File

@ -31,8 +31,8 @@ import org.springframework.security.web.context.SecurityContextPersistenceFilter
import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.security.web.csrf.CsrfFilter;
import org.springframework.security.web.csrf.CsrfTokenRepository;
import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler;
import org.springframework.security.web.csrf.CsrfTokenRequestHandler;
import org.springframework.security.web.csrf.CsrfTokenRequestProcessor;
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.web.context.WebApplicationContext;
@ -48,7 +48,7 @@ public abstract class WebTestUtils {
private static final SecurityContextRepository DEFAULT_CONTEXT_REPO = new HttpSessionSecurityContextRepository();
private static final CsrfTokenRequestProcessor DEFAULT_CSRF_PROCESSOR = new CsrfTokenRequestProcessor();
private static final CsrfTokenRepositoryRequestHandler DEFAULT_CSRF_HANDLER = new CsrfTokenRepositoryRequestHandler();
private WebTestUtils() {
}
@ -104,7 +104,7 @@ public abstract class WebTestUtils {
public static CsrfTokenRequestHandler getCsrfTokenRequestHandler(HttpServletRequest request) {
CsrfFilter filter = findFilter(request, CsrfFilter.class);
if (filter == null) {
return DEFAULT_CSRF_PROCESSOR;
return DEFAULT_CSRF_HANDLER;
}
return (CsrfTokenRequestHandler) ReflectionTestUtils.getField(filter, "requestHandler");
}

View File

@ -39,7 +39,7 @@ import org.springframework.security.web.context.SecurityContextPersistenceFilter
import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.security.web.csrf.CsrfFilter;
import org.springframework.security.web.csrf.CsrfTokenRepository;
import org.springframework.security.web.csrf.CsrfTokenRequestProcessor;
import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler;
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
import org.springframework.security.web.util.matcher.AnyRequestMatcher;
import org.springframework.web.context.WebApplicationContext;
@ -75,19 +75,22 @@ public class WebTestUtilsTests {
@Test
public void getCsrfTokenRepositorytNoWac() {
assertThat(WebTestUtils.getCsrfTokenRequestHandler(this.request)).isInstanceOf(CsrfTokenRequestProcessor.class);
assertThat(WebTestUtils.getCsrfTokenRequestHandler(this.request))
.isInstanceOf(CsrfTokenRepositoryRequestHandler.class);
}
@Test
public void getCsrfTokenRepositorytNoSecurity() {
loadConfig(Config.class);
assertThat(WebTestUtils.getCsrfTokenRequestHandler(this.request)).isInstanceOf(CsrfTokenRequestProcessor.class);
assertThat(WebTestUtils.getCsrfTokenRequestHandler(this.request))
.isInstanceOf(CsrfTokenRepositoryRequestHandler.class);
}
@Test
public void getCsrfTokenRepositorytSecurityNoCsrf() {
loadConfig(SecurityNoCsrfConfig.class);
assertThat(WebTestUtils.getCsrfTokenRequestHandler(this.request)).isInstanceOf(CsrfTokenRequestProcessor.class);
assertThat(WebTestUtils.getCsrfTokenRequestHandler(this.request))
.isInstanceOf(CsrfTokenRepositoryRequestHandler.class);
}
@Test

View File

@ -48,10 +48,7 @@ public final class CsrfAuthenticationStrategy implements SessionAuthenticationSt
* @param csrfTokenRepository the {@link CsrfTokenRepository} to use
*/
public CsrfAuthenticationStrategy(CsrfTokenRepository csrfTokenRepository) {
Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null");
CsrfTokenRequestProcessor processor = new CsrfTokenRequestProcessor();
processor.setTokenRepository(csrfTokenRepository);
this.requestHandler = processor;
this.requestHandler = new CsrfTokenRepositoryRequestHandler(csrfTokenRepository);
this.csrfTokenRepository = csrfTokenRepository;
}

View File

@ -82,20 +82,30 @@ public final class CsrfFilter extends OncePerRequestFilter {
private final Log logger = LogFactory.getLog(getClass());
private final CsrfTokenRequestHandler requestHandler;
private RequestMatcher requireCsrfProtectionMatcher = DEFAULT_CSRF_MATCHER;
private AccessDeniedHandler accessDeniedHandler = new AccessDeniedHandlerImpl();
private CsrfTokenRequestHandler requestHandler;
private CsrfTokenRequestResolver requestResolver;
/**
* Creates a new instance.
* @param csrfTokenRepository the {@link CsrfTokenRepository} to use
* @deprecated Use {@link CsrfFilter#CsrfFilter(CsrfTokenRequestHandler)} instead
*/
@Deprecated
public CsrfFilter(CsrfTokenRepository csrfTokenRepository) {
Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null");
CsrfTokenRequestProcessor csrfTokenRequestProcessor = new CsrfTokenRequestProcessor();
csrfTokenRequestProcessor.setTokenRepository(csrfTokenRepository);
this.requestHandler = csrfTokenRequestProcessor;
this.requestResolver = csrfTokenRequestProcessor;
this(new CsrfTokenRepositoryRequestHandler(csrfTokenRepository));
}
/**
* Creates a new instance.
* @param requestHandler the {@link CsrfTokenRequestHandler} to use. Default is
* {@link CsrfTokenRepositoryRequestHandler}.
*/
public CsrfFilter(CsrfTokenRequestHandler requestHandler) {
Assert.notNull(requestHandler, "requestHandler cannot be null");
this.requestHandler = requestHandler;
}
@Override
@ -116,7 +126,7 @@ public final class CsrfFilter extends OncePerRequestFilter {
return;
}
CsrfToken csrfToken = deferredCsrfToken.get();
String actualToken = this.requestResolver.resolveCsrfTokenValue(request, csrfToken);
String actualToken = this.requestHandler.resolveCsrfTokenValue(request, csrfToken);
if (!equalsConstantTime(csrfToken.getToken(), actualToken)) {
boolean missingToken = deferredCsrfToken.isGenerated();
this.logger.debug(
@ -164,36 +174,6 @@ public final class CsrfFilter extends OncePerRequestFilter {
this.accessDeniedHandler = accessDeniedHandler;
}
/**
* Specifies a {@link CsrfTokenRequestHandler} that is used to make the
* {@link CsrfToken} available as a request attribute.
*
* <p>
* The default is {@link CsrfTokenRequestProcessor}.
* </p>
* @param requestHandler the {@link CsrfTokenRequestHandler} to use
* @since 5.8
*/
public void setRequestHandler(CsrfTokenRequestHandler requestHandler) {
Assert.notNull(requestHandler, "requestHandler cannot be null");
this.requestHandler = requestHandler;
}
/**
* Specifies a {@link CsrfTokenRequestResolver} that is used to resolve the token
* value from the request.
*
* <p>
* The default is {@link CsrfTokenRequestProcessor}.
* </p>
* @param requestResolver the {@link CsrfTokenRequestResolver} to use
* @since 5.8
*/
public void setRequestResolver(CsrfTokenRequestResolver requestResolver) {
Assert.notNull(requestResolver, "requestResolver cannot be null");
this.requestResolver = requestResolver;
}
/**
* Constant time comparison to prevent against timing attacks.
* @param expected

View File

@ -24,28 +24,34 @@ import javax.servlet.http.HttpServletResponse;
import org.springframework.util.Assert;
/**
* An implementation of the {@link CsrfTokenRequestHandler} and
* {@link CsrfTokenRequestResolver} interfaces that is capable of making the
* {@link CsrfToken} available as a request attribute and resolving the token value as
* either a header or parameter value of the request.
* An implementation of the {@link CsrfTokenRequestHandler} interface that is capable of
* making the {@link CsrfToken} available as a request attribute and resolving the token
* value as either a header or parameter value of the request.
*
* @author Steve Riesenberg
* @since 5.8
*/
public class CsrfTokenRequestProcessor implements CsrfTokenRequestHandler, CsrfTokenRequestResolver {
public class CsrfTokenRepositoryRequestHandler implements CsrfTokenRequestHandler {
private final CsrfTokenRepository csrfTokenRepository;
private String csrfRequestAttributeName;
private CsrfTokenRepository tokenRepository = new HttpSessionCsrfTokenRepository();
/**
* Creates a new instance.
*/
public CsrfTokenRepositoryRequestHandler() {
this(new HttpSessionCsrfTokenRepository());
}
/**
* Sets the {@link CsrfTokenRepository} to use.
* @param tokenRepository the {@link CsrfTokenRepository} to use. Default
* Creates a new instance.
* @param csrfTokenRepository the {@link CsrfTokenRepository} to use. Default
* {@link HttpSessionCsrfTokenRepository}
*/
public void setTokenRepository(CsrfTokenRepository tokenRepository) {
Assert.notNull(tokenRepository, "tokenRepository cannot be null");
this.tokenRepository = tokenRepository;
public CsrfTokenRepositoryRequestHandler(CsrfTokenRepository csrfTokenRepository) {
Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null");
this.csrfTokenRepository = csrfTokenRepository;
}
/**
@ -75,17 +81,6 @@ public class CsrfTokenRequestProcessor implements CsrfTokenRequestHandler, CsrfT
return deferredCsrfToken;
}
@Override
public String resolveCsrfTokenValue(HttpServletRequest request, CsrfToken csrfToken) {
Assert.notNull(request, "request cannot be null");
Assert.notNull(csrfToken, "csrfToken cannot be null");
String actualToken = request.getHeader(csrfToken.getHeaderName());
if (actualToken == null) {
actualToken = request.getParameter(csrfToken.getParameterName());
}
return actualToken;
}
private static final class SupplierCsrfToken implements CsrfToken {
private final Supplier<CsrfToken> csrfTokenSupplier;
@ -150,11 +145,12 @@ public class CsrfTokenRequestProcessor implements CsrfTokenRequestHandler, CsrfT
if (this.csrfToken != null) {
return;
}
this.csrfToken = CsrfTokenRequestProcessor.this.tokenRepository.loadToken(this.request);
this.csrfToken = CsrfTokenRepositoryRequestHandler.this.csrfTokenRepository.loadToken(this.request);
this.missingToken = (this.csrfToken == null);
if (this.missingToken) {
this.csrfToken = CsrfTokenRequestProcessor.this.tokenRepository.generateToken(this.request);
CsrfTokenRequestProcessor.this.tokenRepository.saveToken(this.csrfToken, this.request, this.response);
this.csrfToken = CsrfTokenRepositoryRequestHandler.this.csrfTokenRepository.generateToken(this.request);
CsrfTokenRepositoryRequestHandler.this.csrfTokenRepository.saveToken(this.csrfToken, this.request,
this.response);
}
}

View File

@ -19,18 +19,20 @@ package org.springframework.security.web.csrf;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.util.Assert;
/**
* A callback interface that is used to determine the {@link CsrfToken} to use and make
* the {@link CsrfToken} available as a request attribute. Implementations of this
* interface may choose to perform additional tasks or customize how the token is made
* available to the application through request attributes.
* An interface that is used to determine the {@link CsrfToken} to use and make the
* {@link CsrfToken} available as a request attribute. Implementations of this interface
* may choose to perform additional tasks or customize how the token is made available to
* the application through request attributes.
*
* @author Steve Riesenberg
* @since 5.8
* @see CsrfTokenRequestProcessor
* @see CsrfTokenRepositoryRequestHandler
*/
@FunctionalInterface
public interface CsrfTokenRequestHandler {
public interface CsrfTokenRequestHandler extends CsrfTokenRequestResolver {
/**
* Handles a request using a {@link CsrfToken}.
@ -39,4 +41,15 @@ public interface CsrfTokenRequestHandler {
*/
DeferredCsrfToken handle(HttpServletRequest request, HttpServletResponse response);
@Override
default String resolveCsrfTokenValue(HttpServletRequest request, CsrfToken csrfToken) {
Assert.notNull(request, "request cannot be null");
Assert.notNull(csrfToken, "csrfToken cannot be null");
String actualToken = request.getHeader(csrfToken.getHeaderName());
if (actualToken == null) {
actualToken = request.getParameter(csrfToken.getParameterName());
}
return actualToken;
}
}

View File

@ -25,7 +25,7 @@ import javax.servlet.http.HttpServletRequest;
*
* @author Steve Riesenberg
* @since 5.8
* @see CsrfTokenRequestProcessor
* @see CsrfTokenRepositoryRequestHandler
*/
@FunctionalInterface
public interface CsrfTokenRequestResolver {

View File

@ -86,7 +86,11 @@ public class CsrfFilterTests {
}
private CsrfFilter createCsrfFilter(CsrfTokenRepository repository) {
CsrfFilter filter = new CsrfFilter(repository);
return createCsrfFilter(new CsrfTokenRepositoryRequestHandler(repository));
}
private CsrfFilter createCsrfFilter(CsrfTokenRequestHandler requestHandler) {
CsrfFilter filter = new CsrfFilter(requestHandler);
filter.setRequireCsrfProtectionMatcher(this.requestMatcher);
filter.setAccessDeniedHandler(this.deniedHandler);
return filter;
@ -99,7 +103,7 @@ public class CsrfFilterTests {
@Test
public void constructorNullRepository() {
assertThatIllegalArgumentException().isThrownBy(() -> new CsrfFilter(null));
assertThatIllegalArgumentException().isThrownBy(() -> new CsrfFilter((CsrfTokenRequestHandler) null));
}
// SEC-2276
@ -249,7 +253,7 @@ public class CsrfFilterTests {
@Test
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods() throws ServletException, IOException {
this.filter = new CsrfFilter(this.tokenRepository);
this.filter = createCsrfFilter(this.tokenRepository);
this.filter.setAccessDeniedHandler(this.deniedHandler);
for (String method : Arrays.asList("GET", "TRACE", "OPTIONS", "HEAD")) {
resetRequestResponse();
@ -269,7 +273,7 @@ public class CsrfFilterTests {
*/
@Test
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethodsCaseSensitive() throws Exception {
this.filter = new CsrfFilter(this.tokenRepository);
this.filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.tokenRepository));
this.filter.setAccessDeniedHandler(this.deniedHandler);
for (String method : Arrays.asList("get", "TrAcE", "oPTIOnS", "hEaD")) {
resetRequestResponse();
@ -284,7 +288,7 @@ public class CsrfFilterTests {
@Test
public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods() throws ServletException, IOException {
this.filter = new CsrfFilter(this.tokenRepository);
this.filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.tokenRepository));
this.filter.setAccessDeniedHandler(this.deniedHandler);
for (String method : Arrays.asList("POST", "PUT", "PATCH", "DELETE", "INVALID")) {
resetRequestResponse();
@ -299,7 +303,7 @@ public class CsrfFilterTests {
@Test
public void doFilterDefaultAccessDenied() throws ServletException, IOException {
this.filter = new CsrfFilter(this.tokenRepository);
this.filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.tokenRepository));
this.filter.setRequireCsrfProtectionMatcher(this.requestMatcher);
given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
@ -313,7 +317,7 @@ public class CsrfFilterTests {
@Test
public void doFilterWhenSkipRequestInvokedThenSkips() throws Exception {
CsrfTokenRepository repository = mock(CsrfTokenRepository.class);
CsrfFilter filter = new CsrfFilter(repository);
CsrfFilter filter = createCsrfFilter(repository);
lenient().when(repository.loadToken(any(HttpServletRequest.class))).thenReturn(this.token);
MockHttpServletRequest request = new MockHttpServletRequest();
CsrfFilter.skipRequest(request);
@ -340,25 +344,13 @@ public class CsrfFilterTests {
CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class);
given(requestHandler.handle(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
this.filter.setRequestHandler(requestHandler);
this.filter = createCsrfFilter(requestHandler);
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
this.filter.doFilter(this.request, this.response, this.filterChain);
verify(requestHandler).handle(eq(this.request), eq(this.response));
verify(this.filterChain).doFilter(this.request, this.response);
}
@Test
public void doFilterWhenRequestResolverThenUsed() throws Exception {
given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
CsrfTokenRequestResolver requestResolver = mock(CsrfTokenRequestResolver.class);
given(requestResolver.resolveCsrfTokenValue(this.request, this.token)).willReturn(this.token.getToken());
this.filter.setRequestResolver(requestResolver);
this.filter.doFilter(this.request, this.response, this.filterChain);
verify(requestResolver).resolveCsrfTokenValue(this.request, this.token);
verify(this.filterChain).doFilter(this.request, this.response);
}
@Test
public void setRequireCsrfProtectionMatcherNull() {
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setRequireCsrfProtectionMatcher(null));
@ -373,16 +365,14 @@ public class CsrfFilterTests {
@Test
public void doFilterWhenCsrfRequestAttributeNameThenNoCsrfTokenMethodInvokedOnGet()
throws ServletException, IOException {
CsrfFilter filter = createCsrfFilter(this.tokenRepository);
String csrfAttrName = "_csrf";
CsrfTokenRequestProcessor csrfTokenRequestProcessor = new CsrfTokenRequestProcessor();
csrfTokenRequestProcessor.setTokenRepository(this.tokenRepository);
csrfTokenRequestProcessor.setCsrfRequestAttributeName(csrfAttrName);
filter.setRequestHandler(csrfTokenRequestProcessor);
CsrfTokenRepositoryRequestHandler requestHandler = new CsrfTokenRepositoryRequestHandler(this.tokenRepository);
requestHandler.setCsrfRequestAttributeName(csrfAttrName);
this.filter = createCsrfFilter(requestHandler);
CsrfToken expectedCsrfToken = spy(this.token);
given(this.tokenRepository.loadToken(this.request)).willReturn(expectedCsrfToken);
filter.doFilter(this.request, this.response, this.filterChain);
this.filter.doFilter(this.request, this.response, this.filterChain);
verifyNoInteractions(expectedCsrfToken);
CsrfToken tokenFromRequest = (CsrfToken) this.request.getAttribute(csrfAttrName);
@ -410,6 +400,6 @@ public class CsrfFilterTests {
return this.isGenerated;
}
};
}
}

View File

@ -31,13 +31,13 @@ import static org.mockito.BDDMockito.given;
import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;
/**
* Tests for {@link CsrfTokenRequestProcessor}.
* Tests for {@link CsrfTokenRepositoryRequestHandler}.
*
* @author Steve Riesenberg
* @since 5.8
*/
@ExtendWith(MockitoExtension.class)
public class CsrfTokenRequestProcessorTests {
public class CsrfTokenRepositoryRequestHandlerTests {
@Mock
CsrfTokenRepository tokenRepository;
@ -48,34 +48,48 @@ public class CsrfTokenRequestProcessorTests {
private CsrfToken token;
private CsrfTokenRequestProcessor processor;
private CsrfTokenRepositoryRequestHandler handler;
@BeforeEach
public void setup() {
this.request = new MockHttpServletRequest();
this.response = new MockHttpServletResponse();
this.token = new DefaultCsrfToken("headerName", "paramName", "csrfTokenValue");
this.processor = new CsrfTokenRequestProcessor();
this.processor.setTokenRepository(this.tokenRepository);
this.handler = new CsrfTokenRepositoryRequestHandler(this.tokenRepository);
}
@Test
public void constructorWhenCsrfTokenRepositoryIsNullThenThrowsIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> new CsrfTokenRepositoryRequestHandler(null))
.withMessage("csrfTokenRepository cannot be null");
// @formatter:on
}
@Test
public void handleWhenRequestIsNullThenThrowsIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.processor.handle(null, this.response))
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.handler.handle(null, this.response))
.withMessage("request cannot be null");
// @formatter:on
}
@Test
public void handleWhenResponseIsNullThenThrowsIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.processor.handle(this.request, null))
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.handler.handle(this.request, null))
.withMessage("response cannot be null");
// @formatter:on
}
@Test
public void handleWhenCsrfRequestAttributeSetThenUsed() {
given(this.tokenRepository.generateToken(this.request)).willReturn(this.token);
this.processor.setCsrfRequestAttributeName("_csrf");
this.processor.handle(this.request, this.response);
this.handler.setCsrfRequestAttributeName("_csrf");
this.handler.handle(this.request, this.response);
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute("_csrf")).isEqualTo(this.token);
}
@ -83,40 +97,46 @@ public class CsrfTokenRequestProcessorTests {
@Test
public void handleWhenValidParametersThenRequestAttributesSet() {
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
this.processor.handle(this.request, this.response);
this.handler.handle(this.request, this.response);
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
}
@Test
public void resolveCsrfTokenValueWhenRequestIsNullThenThrowsIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.processor.resolveCsrfTokenValue(null, this.token))
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.handler.resolveCsrfTokenValue(null, this.token))
.withMessage("request cannot be null");
// @formatter:on
}
@Test
public void resolveCsrfTokenValueWhenCsrfTokenIsNullThenThrowsIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.processor.resolveCsrfTokenValue(this.request, null))
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.handler.resolveCsrfTokenValue(this.request, null))
.withMessage("csrfToken cannot be null");
// @formatter:on
}
@Test
public void resolveCsrfTokenValueWhenTokenNotSetThenReturnsNull() {
String tokenValue = this.processor.resolveCsrfTokenValue(this.request, this.token);
String tokenValue = this.handler.resolveCsrfTokenValue(this.request, this.token);
assertThat(tokenValue).isNull();
}
@Test
public void resolveCsrfTokenValueWhenParameterSetThenReturnsTokenValue() {
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
String tokenValue = this.processor.resolveCsrfTokenValue(this.request, this.token);
String tokenValue = this.handler.resolveCsrfTokenValue(this.request, this.token);
assertThat(tokenValue).isEqualTo(this.token.getToken());
}
@Test
public void resolveCsrfTokenValueWhenHeaderSetThenReturnsTokenValue() {
this.request.addHeader(this.token.getHeaderName(), this.token.getToken());
String tokenValue = this.processor.resolveCsrfTokenValue(this.request, this.token);
String tokenValue = this.handler.resolveCsrfTokenValue(this.request, this.token);
assertThat(tokenValue).isEqualTo(this.token.getToken());
}
@ -124,7 +144,7 @@ public class CsrfTokenRequestProcessorTests {
public void resolveCsrfTokenValueWhenHeaderAndParameterSetThenHeaderIsPreferred() {
this.request.addHeader(this.token.getHeaderName(), "header");
this.request.setParameter(this.token.getParameterName(), "parameter");
String tokenValue = this.processor.resolveCsrfTokenValue(this.request, this.token);
String tokenValue = this.handler.resolveCsrfTokenValue(this.request, this.token);
assertThat(tokenValue).isEqualTo("header");
}