diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java index 4363129c2d..eaa466f4a8 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java @@ -36,8 +36,8 @@ import org.springframework.security.web.csrf.CsrfAuthenticationStrategy; import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfLogoutHandler; import org.springframework.security.web.csrf.CsrfTokenRepository; +import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler; import org.springframework.security.web.csrf.CsrfTokenRequestHandler; -import org.springframework.security.web.csrf.CsrfTokenRequestResolver; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; import org.springframework.security.web.csrf.LazyCsrfTokenRepository; import org.springframework.security.web.csrf.MissingCsrfTokenException; @@ -93,8 +93,6 @@ public final class CsrfConfigurer> private CsrfTokenRequestHandler requestHandler; - private CsrfTokenRequestResolver requestResolver; - private final ApplicationContext context; /** @@ -135,23 +133,13 @@ public final class CsrfConfigurer> * available as a request attribute. * @param requestHandler the {@link CsrfTokenRequestHandler} to use * @return the {@link CsrfConfigurer} for further customizations + * @since 5.8 */ public CsrfConfigurer csrfTokenRequestHandler(CsrfTokenRequestHandler requestHandler) { this.requestHandler = requestHandler; return this; } - /** - * Specify a {@link CsrfTokenRequestResolver} to use for resolving the token value - * from the request. - * @param requestResolver the {@link CsrfTokenRequestResolver} to use - * @return the {@link CsrfConfigurer} for further customizations - */ - public CsrfConfigurer csrfTokenRequestResolver(CsrfTokenRequestResolver requestResolver) { - this.requestResolver = requestResolver; - return this; - } - /** *

* Allows specifying {@link HttpServletRequest} that should not use CSRF Protection @@ -229,7 +217,13 @@ public final class CsrfConfigurer> @SuppressWarnings("unchecked") @Override public void configure(H http) { - CsrfFilter filter = new CsrfFilter(this.csrfTokenRepository); + CsrfFilter filter; + if (this.requestHandler != null) { + filter = new CsrfFilter(this.requestHandler); + } + else { + filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.csrfTokenRepository)); + } RequestMatcher requireCsrfProtectionMatcher = getRequireCsrfProtectionMatcher(); if (requireCsrfProtectionMatcher != null) { filter.setRequireCsrfProtectionMatcher(requireCsrfProtectionMatcher); @@ -246,12 +240,6 @@ public final class CsrfConfigurer> if (sessionConfigurer != null) { sessionConfigurer.addSessionAuthenticationStrategy(getSessionAuthenticationStrategy()); } - if (this.requestHandler != null) { - filter.setRequestHandler(this.requestHandler); - } - if (this.requestResolver != null) { - filter.setRequestResolver(this.requestResolver); - } filter = postProcess(filter); http.addFilter(filter); } diff --git a/config/src/main/java/org/springframework/security/config/http/CsrfBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/http/CsrfBeanDefinitionParser.java index 49ba28cab6..12eee0b11f 100644 --- a/config/src/main/java/org/springframework/security/config/http/CsrfBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/http/CsrfBeanDefinitionParser.java @@ -40,6 +40,7 @@ import org.springframework.security.web.access.DelegatingAccessDeniedHandler; import org.springframework.security.web.csrf.CsrfAuthenticationStrategy; import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfLogoutHandler; +import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; import org.springframework.security.web.csrf.LazyCsrfTokenRepository; import org.springframework.security.web.csrf.MissingCsrfTokenException; @@ -72,8 +73,6 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser { private static final String ATT_REQUEST_HANDLER = "request-handler-ref"; - private static final String ATT_REQUEST_RESOLVER = "request-resolver-ref"; - private String csrfRepositoryRef; private BeanDefinition csrfFilter; @@ -82,8 +81,6 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser { private String requestHandlerRef; - private String requestResolverRef; - @Override public BeanDefinition parse(Element element, ParserContext pc) { boolean disabled = element != null && "true".equals(element.getAttribute("disabled")); @@ -103,7 +100,6 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser { this.csrfRepositoryRef = element.getAttribute(ATT_REPOSITORY); this.requestMatcherRef = element.getAttribute(ATT_MATCHER); this.requestHandlerRef = element.getAttribute(ATT_REQUEST_HANDLER); - this.requestResolverRef = element.getAttribute(ATT_REQUEST_RESOLVER); } if (!StringUtils.hasText(this.csrfRepositoryRef)) { RootBeanDefinition csrfTokenRepository = new RootBeanDefinition(HttpSessionCsrfTokenRepository.class); @@ -115,16 +111,18 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser { new BeanComponentDefinition(lazyTokenRepository.getBeanDefinition(), this.csrfRepositoryRef)); } BeanDefinitionBuilder builder = BeanDefinitionBuilder.rootBeanDefinition(CsrfFilter.class); - builder.addConstructorArgReference(this.csrfRepositoryRef); + if (!StringUtils.hasText(this.requestHandlerRef)) { + BeanDefinition csrfTokenRequestHandler = BeanDefinitionBuilder + .rootBeanDefinition(CsrfTokenRepositoryRequestHandler.class) + .addConstructorArgReference(this.csrfRepositoryRef).getBeanDefinition(); + builder.addConstructorArgValue(csrfTokenRequestHandler); + } + else { + builder.addConstructorArgReference(this.requestHandlerRef); + } if (StringUtils.hasText(this.requestMatcherRef)) { builder.addPropertyReference("requireCsrfProtectionMatcher", this.requestMatcherRef); } - if (StringUtils.hasText(this.requestHandlerRef)) { - builder.addPropertyReference("requestHandler", this.requestHandlerRef); - } - if (StringUtils.hasText(this.requestResolverRef)) { - builder.addPropertyReference("requestResolver", this.requestResolverRef); - } this.csrfFilter = builder.getBeanDefinition(); return this.csrfFilter; } diff --git a/config/src/main/resources/org/springframework/security/config/spring-security-5.8.rnc b/config/src/main/resources/org/springframework/security/config/spring-security-5.8.rnc index 0738bd164e..5e61b3ee74 100644 --- a/config/src/main/resources/org/springframework/security/config/spring-security-5.8.rnc +++ b/config/src/main/resources/org/springframework/security/config/spring-security-5.8.rnc @@ -1154,9 +1154,6 @@ csrf-options.attlist &= csrf-options.attlist &= ## The CsrfTokenRequestHandler to use. The default is CsrfTokenRequestProcessor. attribute request-handler-ref { xsd:token }? -csrf-options.attlist &= - ## The CsrfTokenRequestResolver to use. The default is CsrfTokenRequestProcessor. - attribute request-resolver-ref { xsd:token }? headers = ## Element for configuration of the HeaderWritersFilter. Enables easy setting for the X-Frame-Options, X-XSS-Protection and X-Content-Type-Options headers. diff --git a/config/src/main/resources/org/springframework/security/config/spring-security-5.8.xsd b/config/src/main/resources/org/springframework/security/config/spring-security-5.8.xsd index fbc507bdcf..b5642bb293 100644 --- a/config/src/main/resources/org/springframework/security/config/spring-security-5.8.xsd +++ b/config/src/main/resources/org/springframework/security/config/spring-security-5.8.xsd @@ -3258,13 +3258,7 @@ - The CsrfTokenRequestHandler to use. The default is CsrfTokenRequestProcessor. - - - - - - The CsrfTokenRequestResolver to use. The default is CsrfTokenRequestProcessor. + The CsrfTokenRequestHandler to use. The default is CsrfTokenRepositoryRequestHandler. diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java index bd0a93b4ae..2a6ab1ba06 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java @@ -43,7 +43,7 @@ import org.springframework.security.web.access.AccessDeniedHandler; import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy; import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.CsrfTokenRepository; -import org.springframework.security.web.csrf.CsrfTokenRequestProcessor; +import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler; import org.springframework.security.web.csrf.DefaultCsrfToken; import org.springframework.security.web.firewall.StrictHttpFirewall; import org.springframework.security.web.savedrequest.HttpSessionRequestCache; @@ -424,8 +424,7 @@ public class CsrfConfigurerTests { CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class); CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"); given(csrfTokenRepository.generateToken(any(HttpServletRequest.class))).willReturn(csrfToken); - CsrfTokenRequestProcessorConfig.PROCESSOR = new CsrfTokenRequestProcessor(); - CsrfTokenRequestProcessorConfig.PROCESSOR.setTokenRepository(csrfTokenRepository); + CsrfTokenRequestProcessorConfig.HANDLER = new CsrfTokenRepositoryRequestHandler(csrfTokenRepository); this.spring.register(CsrfTokenRequestProcessorConfig.class, BasicController.class).autowire(); this.mvc.perform(get("/login")).andExpect(status().isOk()) .andExpect(content().string(containsString(csrfToken.getToken()))); @@ -442,8 +441,7 @@ public class CsrfConfigurerTests { CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class); given(csrfTokenRepository.loadToken(any(HttpServletRequest.class))).willReturn(null, csrfToken); given(csrfTokenRepository.generateToken(any(HttpServletRequest.class))).willReturn(csrfToken); - CsrfTokenRequestProcessorConfig.PROCESSOR = new CsrfTokenRequestProcessor(); - CsrfTokenRequestProcessorConfig.PROCESSOR.setTokenRepository(csrfTokenRepository); + CsrfTokenRequestProcessorConfig.HANDLER = new CsrfTokenRepositoryRequestHandler(csrfTokenRepository); this.spring.register(CsrfTokenRequestProcessorConfig.class, BasicController.class).autowire(); // @formatter:off @@ -823,7 +821,7 @@ public class CsrfConfigurerTests { @EnableWebSecurity static class CsrfTokenRequestProcessorConfig { - static CsrfTokenRequestProcessor PROCESSOR; + static CsrfTokenRepositoryRequestHandler HANDLER; @Bean SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception { @@ -833,10 +831,7 @@ public class CsrfConfigurerTests { .anyRequest().authenticated() ) .formLogin(Customizer.withDefaults()) - .csrf((csrf) -> csrf - .csrfTokenRequestHandler(PROCESSOR) - .csrfTokenRequestResolver(PROCESSOR) - ); + .csrf((csrf) -> csrf.csrfTokenRequestHandler(HANDLER)); // @formatter:on return http.build(); diff --git a/config/src/test/resources/org/springframework/security/config/http/CsrfConfigTests-WithRequestAttrName.xml b/config/src/test/resources/org/springframework/security/config/http/CsrfConfigTests-WithRequestAttrName.xml index 37950840c8..2ea2dd7c87 100644 --- a/config/src/test/resources/org/springframework/security/config/http/CsrfConfigTests-WithRequestAttrName.xml +++ b/config/src/test/resources/org/springframework/security/config/http/CsrfConfigTests-WithRequestAttrName.xml @@ -26,7 +26,7 @@ - diff --git a/config/src/test/resources/org/springframework/security/config/http/DeferHttpSessionTests-Explicit.xml b/config/src/test/resources/org/springframework/security/config/http/DeferHttpSessionTests-Explicit.xml index 2a13f5fa4d..23411a7e77 100644 --- a/config/src/test/resources/org/springframework/security/config/http/DeferHttpSessionTests-Explicit.xml +++ b/config/src/test/resources/org/springframework/security/config/http/DeferHttpSessionTests-Explicit.xml @@ -40,7 +40,7 @@ - diff --git a/docs/modules/ROOT/pages/servlet/appendix/namespace/http.adoc b/docs/modules/ROOT/pages/servlet/appendix/namespace/http.adoc index bb146d8e47..aca8ca3e51 100644 --- a/docs/modules/ROOT/pages/servlet/appendix/namespace/http.adoc +++ b/docs/modules/ROOT/pages/servlet/appendix/namespace/http.adoc @@ -776,11 +776,7 @@ The default is `HttpSessionCsrfTokenRepository`. [[nsa-csrf-request-handler-ref]] * **request-handler-ref** -The optional `CsrfTokenRequestHandler` to use. The default is `CsrfTokenRequestProcessor`. - -[[nsa-csrf-request-resolver-ref]] -* **request-resolver-ref** -The optional `CsrfTokenRequestResolver` to use. The default is `CsrfTokenRequestProcessor`. +The optional `CsrfTokenRequestHandler` to use. The default is `CsrfTokenRepositoryRequestHandler`. [[nsa-csrf-request-matcher-ref]] * **request-matcher-ref** diff --git a/test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java b/test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java index ffb2c131a2..e9954078a8 100644 --- a/test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java +++ b/test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java @@ -31,8 +31,8 @@ import org.springframework.security.web.context.SecurityContextPersistenceFilter import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfTokenRepository; +import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler; import org.springframework.security.web.csrf.CsrfTokenRequestHandler; -import org.springframework.security.web.csrf.CsrfTokenRequestProcessor; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; import org.springframework.test.util.ReflectionTestUtils; import org.springframework.web.context.WebApplicationContext; @@ -48,7 +48,7 @@ public abstract class WebTestUtils { private static final SecurityContextRepository DEFAULT_CONTEXT_REPO = new HttpSessionSecurityContextRepository(); - private static final CsrfTokenRequestProcessor DEFAULT_CSRF_PROCESSOR = new CsrfTokenRequestProcessor(); + private static final CsrfTokenRepositoryRequestHandler DEFAULT_CSRF_HANDLER = new CsrfTokenRepositoryRequestHandler(); private WebTestUtils() { } @@ -104,7 +104,7 @@ public abstract class WebTestUtils { public static CsrfTokenRequestHandler getCsrfTokenRequestHandler(HttpServletRequest request) { CsrfFilter filter = findFilter(request, CsrfFilter.class); if (filter == null) { - return DEFAULT_CSRF_PROCESSOR; + return DEFAULT_CSRF_HANDLER; } return (CsrfTokenRequestHandler) ReflectionTestUtils.getField(filter, "requestHandler"); } diff --git a/test/src/test/java/org/springframework/security/test/web/support/WebTestUtilsTests.java b/test/src/test/java/org/springframework/security/test/web/support/WebTestUtilsTests.java index 38b4faa7c4..7bbc4998ef 100644 --- a/test/src/test/java/org/springframework/security/test/web/support/WebTestUtilsTests.java +++ b/test/src/test/java/org/springframework/security/test/web/support/WebTestUtilsTests.java @@ -39,7 +39,7 @@ import org.springframework.security.web.context.SecurityContextPersistenceFilter import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfTokenRepository; -import org.springframework.security.web.csrf.CsrfTokenRequestProcessor; +import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; import org.springframework.security.web.util.matcher.AnyRequestMatcher; import org.springframework.web.context.WebApplicationContext; @@ -75,19 +75,22 @@ public class WebTestUtilsTests { @Test public void getCsrfTokenRepositorytNoWac() { - assertThat(WebTestUtils.getCsrfTokenRequestHandler(this.request)).isInstanceOf(CsrfTokenRequestProcessor.class); + assertThat(WebTestUtils.getCsrfTokenRequestHandler(this.request)) + .isInstanceOf(CsrfTokenRepositoryRequestHandler.class); } @Test public void getCsrfTokenRepositorytNoSecurity() { loadConfig(Config.class); - assertThat(WebTestUtils.getCsrfTokenRequestHandler(this.request)).isInstanceOf(CsrfTokenRequestProcessor.class); + assertThat(WebTestUtils.getCsrfTokenRequestHandler(this.request)) + .isInstanceOf(CsrfTokenRepositoryRequestHandler.class); } @Test public void getCsrfTokenRepositorytSecurityNoCsrf() { loadConfig(SecurityNoCsrfConfig.class); - assertThat(WebTestUtils.getCsrfTokenRequestHandler(this.request)).isInstanceOf(CsrfTokenRequestProcessor.class); + assertThat(WebTestUtils.getCsrfTokenRequestHandler(this.request)) + .isInstanceOf(CsrfTokenRepositoryRequestHandler.class); } @Test diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java index 5da2cf58ca..850cbbc37b 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java @@ -47,10 +47,7 @@ public final class CsrfAuthenticationStrategy implements SessionAuthenticationSt * @param csrfTokenRepository the {@link CsrfTokenRepository} to use */ public CsrfAuthenticationStrategy(CsrfTokenRepository csrfTokenRepository) { - Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null"); - CsrfTokenRequestProcessor processor = new CsrfTokenRequestProcessor(); - processor.setTokenRepository(csrfTokenRepository); - this.requestHandler = processor; + this.requestHandler = new CsrfTokenRepositoryRequestHandler(csrfTokenRepository); this.csrfTokenRepository = csrfTokenRepository; } diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java index eb2ab9f979..4b77201e6f 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java @@ -81,20 +81,30 @@ public final class CsrfFilter extends OncePerRequestFilter { private final Log logger = LogFactory.getLog(getClass()); + private final CsrfTokenRequestHandler requestHandler; + private RequestMatcher requireCsrfProtectionMatcher = DEFAULT_CSRF_MATCHER; private AccessDeniedHandler accessDeniedHandler = new AccessDeniedHandlerImpl(); - private CsrfTokenRequestHandler requestHandler; - - private CsrfTokenRequestResolver requestResolver; - + /** + * Creates a new instance. + * @param csrfTokenRepository the {@link CsrfTokenRepository} to use + * @deprecated Use {@link CsrfFilter#CsrfFilter(CsrfTokenRequestHandler)} instead + */ + @Deprecated public CsrfFilter(CsrfTokenRepository csrfTokenRepository) { - Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null"); - CsrfTokenRequestProcessor csrfTokenRequestProcessor = new CsrfTokenRequestProcessor(); - csrfTokenRequestProcessor.setTokenRepository(csrfTokenRepository); - this.requestHandler = csrfTokenRequestProcessor; - this.requestResolver = csrfTokenRequestProcessor; + this(new CsrfTokenRepositoryRequestHandler(csrfTokenRepository)); + } + + /** + * Creates a new instance. + * @param requestHandler the {@link CsrfTokenRequestHandler} to use. Default is + * {@link CsrfTokenRepositoryRequestHandler}. + */ + public CsrfFilter(CsrfTokenRequestHandler requestHandler) { + Assert.notNull(requestHandler, "requestHandler cannot be null"); + this.requestHandler = requestHandler; } @Override @@ -115,7 +125,7 @@ public final class CsrfFilter extends OncePerRequestFilter { return; } CsrfToken csrfToken = deferredCsrfToken.get(); - String actualToken = this.requestResolver.resolveCsrfTokenValue(request, csrfToken); + String actualToken = this.requestHandler.resolveCsrfTokenValue(request, csrfToken); if (!equalsConstantTime(csrfToken.getToken(), actualToken)) { boolean missingToken = deferredCsrfToken.isGenerated(); this.logger.debug( @@ -163,36 +173,6 @@ public final class CsrfFilter extends OncePerRequestFilter { this.accessDeniedHandler = accessDeniedHandler; } - /** - * Specifies a {@link CsrfTokenRequestHandler} that is used to make the - * {@link CsrfToken} available as a request attribute. - * - *

- * The default is {@link CsrfTokenRequestProcessor}. - *

- * @param requestHandler the {@link CsrfTokenRequestHandler} to use - * @since 5.8 - */ - public void setRequestHandler(CsrfTokenRequestHandler requestHandler) { - Assert.notNull(requestHandler, "requestHandler cannot be null"); - this.requestHandler = requestHandler; - } - - /** - * Specifies a {@link CsrfTokenRequestResolver} that is used to resolve the token - * value from the request. - * - *

- * The default is {@link CsrfTokenRequestProcessor}. - *

- * @param requestResolver the {@link CsrfTokenRequestResolver} to use - * @since 5.8 - */ - public void setRequestResolver(CsrfTokenRequestResolver requestResolver) { - Assert.notNull(requestResolver, "requestResolver cannot be null"); - this.requestResolver = requestResolver; - } - /** * Constant time comparison to prevent against timing attacks. * @param expected diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestProcessor.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRepositoryRequestHandler.java similarity index 72% rename from web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestProcessor.java rename to web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRepositoryRequestHandler.java index dec8e54e8b..ef05dc776b 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestProcessor.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRepositoryRequestHandler.java @@ -24,28 +24,34 @@ import jakarta.servlet.http.HttpServletResponse; import org.springframework.util.Assert; /** - * An implementation of the {@link CsrfTokenRequestHandler} and - * {@link CsrfTokenRequestResolver} interfaces that is capable of making the - * {@link CsrfToken} available as a request attribute and resolving the token value as - * either a header or parameter value of the request. + * An implementation of the {@link CsrfTokenRequestHandler} interface that is capable of + * making the {@link CsrfToken} available as a request attribute and resolving the token + * value as either a header or parameter value of the request. * * @author Steve Riesenberg * @since 5.8 */ -public class CsrfTokenRequestProcessor implements CsrfTokenRequestHandler, CsrfTokenRequestResolver { +public class CsrfTokenRepositoryRequestHandler implements CsrfTokenRequestHandler { + + private final CsrfTokenRepository csrfTokenRepository; private String csrfRequestAttributeName = "_csrf"; - private CsrfTokenRepository tokenRepository = new HttpSessionCsrfTokenRepository(); + /** + * Creates a new instance. + */ + public CsrfTokenRepositoryRequestHandler() { + this(new HttpSessionCsrfTokenRepository()); + } /** - * Sets the {@link CsrfTokenRepository} to use. - * @param tokenRepository the {@link CsrfTokenRepository} to use. Default + * Creates a new instance. + * @param csrfTokenRepository the {@link CsrfTokenRepository} to use. Default * {@link HttpSessionCsrfTokenRepository} */ - public void setTokenRepository(CsrfTokenRepository tokenRepository) { - Assert.notNull(tokenRepository, "tokenRepository cannot be null"); - this.tokenRepository = tokenRepository; + public CsrfTokenRepositoryRequestHandler(CsrfTokenRepository csrfTokenRepository) { + Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null"); + this.csrfTokenRepository = csrfTokenRepository; } /** @@ -75,17 +81,6 @@ public class CsrfTokenRequestProcessor implements CsrfTokenRequestHandler, CsrfT return deferredCsrfToken; } - @Override - public String resolveCsrfTokenValue(HttpServletRequest request, CsrfToken csrfToken) { - Assert.notNull(request, "request cannot be null"); - Assert.notNull(csrfToken, "csrfToken cannot be null"); - String actualToken = request.getHeader(csrfToken.getHeaderName()); - if (actualToken == null) { - actualToken = request.getParameter(csrfToken.getParameterName()); - } - return actualToken; - } - private static final class SupplierCsrfToken implements CsrfToken { private final Supplier csrfTokenSupplier; @@ -150,11 +145,12 @@ public class CsrfTokenRequestProcessor implements CsrfTokenRequestHandler, CsrfT if (this.csrfToken != null) { return; } - this.csrfToken = CsrfTokenRequestProcessor.this.tokenRepository.loadToken(this.request); + this.csrfToken = CsrfTokenRepositoryRequestHandler.this.csrfTokenRepository.loadToken(this.request); this.missingToken = (this.csrfToken == null); if (this.missingToken) { - this.csrfToken = CsrfTokenRequestProcessor.this.tokenRepository.generateToken(this.request); - CsrfTokenRequestProcessor.this.tokenRepository.saveToken(this.csrfToken, this.request, this.response); + this.csrfToken = CsrfTokenRepositoryRequestHandler.this.csrfTokenRepository.generateToken(this.request); + CsrfTokenRepositoryRequestHandler.this.csrfTokenRepository.saveToken(this.csrfToken, this.request, + this.response); } } diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestHandler.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestHandler.java index 6fc4db61f5..f84e9b8cc1 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestHandler.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestHandler.java @@ -19,18 +19,20 @@ package org.springframework.security.web.csrf; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; +import org.springframework.util.Assert; + /** - * A callback interface that is used to determine the {@link CsrfToken} to use and make - * the {@link CsrfToken} available as a request attribute. Implementations of this - * interface may choose to perform additional tasks or customize how the token is made - * available to the application through request attributes. + * An interface that is used to determine the {@link CsrfToken} to use and make the + * {@link CsrfToken} available as a request attribute. Implementations of this interface + * may choose to perform additional tasks or customize how the token is made available to + * the application through request attributes. * * @author Steve Riesenberg * @since 5.8 - * @see CsrfTokenRequestProcessor + * @see CsrfTokenRepositoryRequestHandler */ @FunctionalInterface -public interface CsrfTokenRequestHandler { +public interface CsrfTokenRequestHandler extends CsrfTokenRequestResolver { /** * Handles a request using a {@link CsrfToken}. @@ -39,4 +41,15 @@ public interface CsrfTokenRequestHandler { */ DeferredCsrfToken handle(HttpServletRequest request, HttpServletResponse response); + @Override + default String resolveCsrfTokenValue(HttpServletRequest request, CsrfToken csrfToken) { + Assert.notNull(request, "request cannot be null"); + Assert.notNull(csrfToken, "csrfToken cannot be null"); + String actualToken = request.getHeader(csrfToken.getHeaderName()); + if (actualToken == null) { + actualToken = request.getParameter(csrfToken.getParameterName()); + } + return actualToken; + } + } diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestResolver.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestResolver.java index 38089a066c..fc381d72c5 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestResolver.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestResolver.java @@ -25,7 +25,7 @@ import jakarta.servlet.http.HttpServletRequest; * * @author Steve Riesenberg * @since 5.8 - * @see CsrfTokenRequestProcessor + * @see CsrfTokenRepositoryRequestHandler */ @FunctionalInterface public interface CsrfTokenRequestResolver { diff --git a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java index 335a49471d..18e5a9a7ea 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java @@ -87,7 +87,11 @@ public class CsrfFilterTests { } private CsrfFilter createCsrfFilter(CsrfTokenRepository repository) { - CsrfFilter filter = new CsrfFilter(repository); + return createCsrfFilter(new CsrfTokenRepositoryRequestHandler(repository)); + } + + private CsrfFilter createCsrfFilter(CsrfTokenRequestHandler requestHandler) { + CsrfFilter filter = new CsrfFilter(requestHandler); filter.setRequireCsrfProtectionMatcher(this.requestMatcher); filter.setAccessDeniedHandler(this.deniedHandler); return filter; @@ -100,7 +104,7 @@ public class CsrfFilterTests { @Test public void constructorNullRepository() { - assertThatIllegalArgumentException().isThrownBy(() -> new CsrfFilter(null)); + assertThatIllegalArgumentException().isThrownBy(() -> new CsrfFilter((CsrfTokenRequestHandler) null)); } // SEC-2276 @@ -250,7 +254,7 @@ public class CsrfFilterTests { @Test public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods() throws ServletException, IOException { - this.filter = new CsrfFilter(this.tokenRepository); + this.filter = createCsrfFilter(this.tokenRepository); this.filter.setAccessDeniedHandler(this.deniedHandler); for (String method : Arrays.asList("GET", "TRACE", "OPTIONS", "HEAD")) { resetRequestResponse(); @@ -269,7 +273,7 @@ public class CsrfFilterTests { */ @Test public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethodsCaseSensitive() throws Exception { - this.filter = new CsrfFilter(this.tokenRepository); + this.filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.tokenRepository)); this.filter.setAccessDeniedHandler(this.deniedHandler); for (String method : Arrays.asList("get", "TrAcE", "oPTIOnS", "hEaD")) { resetRequestResponse(); @@ -284,7 +288,7 @@ public class CsrfFilterTests { @Test public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods() throws ServletException, IOException { - this.filter = new CsrfFilter(this.tokenRepository); + this.filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.tokenRepository)); this.filter.setAccessDeniedHandler(this.deniedHandler); for (String method : Arrays.asList("POST", "PUT", "PATCH", "DELETE", "INVALID")) { resetRequestResponse(); @@ -299,7 +303,7 @@ public class CsrfFilterTests { @Test public void doFilterDefaultAccessDenied() throws ServletException, IOException { - this.filter = new CsrfFilter(this.tokenRepository); + this.filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.tokenRepository)); this.filter.setRequireCsrfProtectionMatcher(this.requestMatcher); given(this.requestMatcher.matches(this.request)).willReturn(true); given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); @@ -313,7 +317,7 @@ public class CsrfFilterTests { @Test public void doFilterWhenSkipRequestInvokedThenSkips() throws Exception { CsrfTokenRepository repository = mock(CsrfTokenRepository.class); - CsrfFilter filter = new CsrfFilter(repository); + CsrfFilter filter = createCsrfFilter(repository); lenient().when(repository.loadToken(any(HttpServletRequest.class))).thenReturn(this.token); MockHttpServletRequest request = new MockHttpServletRequest(); CsrfFilter.skipRequest(request); @@ -340,25 +344,13 @@ public class CsrfFilterTests { CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class); given(requestHandler.handle(this.request, this.response)) .willReturn(new TestDeferredCsrfToken(this.token, false)); - this.filter.setRequestHandler(requestHandler); + this.filter = createCsrfFilter(requestHandler); this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); verify(requestHandler).handle(eq(this.request), eq(this.response)); verify(this.filterChain).doFilter(this.request, this.response); } - @Test - public void doFilterWhenRequestResolverThenUsed() throws Exception { - given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); - CsrfTokenRequestResolver requestResolver = mock(CsrfTokenRequestResolver.class); - given(requestResolver.resolveCsrfTokenValue(this.request, this.token)).willReturn(this.token.getToken()); - this.filter.setRequestResolver(requestResolver); - this.filter.doFilter(this.request, this.response, this.filterChain); - verify(requestResolver).resolveCsrfTokenValue(this.request, this.token); - verify(this.filterChain).doFilter(this.request, this.response); - } - @Test public void setRequireCsrfProtectionMatcherNull() { assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setRequireCsrfProtectionMatcher(null)); @@ -373,16 +365,14 @@ public class CsrfFilterTests { @Test public void doFilterWhenCsrfRequestAttributeNameThenNoCsrfTokenMethodInvokedOnGet() throws ServletException, IOException { - CsrfFilter filter = createCsrfFilter(this.tokenRepository); String csrfAttrName = "_csrf"; - CsrfTokenRequestProcessor csrfTokenRequestProcessor = new CsrfTokenRequestProcessor(); - csrfTokenRequestProcessor.setTokenRepository(this.tokenRepository); - csrfTokenRequestProcessor.setCsrfRequestAttributeName(csrfAttrName); - filter.setRequestHandler(csrfTokenRequestProcessor); + CsrfTokenRepositoryRequestHandler requestHandler = new CsrfTokenRepositoryRequestHandler(this.tokenRepository); + requestHandler.setCsrfRequestAttributeName(csrfAttrName); + this.filter = createCsrfFilter(requestHandler); CsrfToken expectedCsrfToken = spy(this.token); given(this.tokenRepository.loadToken(this.request)).willReturn(expectedCsrfToken); - filter.doFilter(this.request, this.response, this.filterChain); + this.filter.doFilter(this.request, this.response, this.filterChain); verifyNoInteractions(expectedCsrfToken); CsrfToken tokenFromRequest = (CsrfToken) this.request.getAttribute(csrfAttrName); @@ -410,6 +400,6 @@ public class CsrfFilterTests { return this.isGenerated; } - }; + } } diff --git a/web/src/test/java/org/springframework/security/web/csrf/CsrfTokenRequestProcessorTests.java b/web/src/test/java/org/springframework/security/web/csrf/CsrfTokenRepositoryRequestHandlerTests.java similarity index 70% rename from web/src/test/java/org/springframework/security/web/csrf/CsrfTokenRequestProcessorTests.java rename to web/src/test/java/org/springframework/security/web/csrf/CsrfTokenRepositoryRequestHandlerTests.java index 529ab5c821..0756447794 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/CsrfTokenRequestProcessorTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/CsrfTokenRepositoryRequestHandlerTests.java @@ -31,13 +31,13 @@ import static org.mockito.BDDMockito.given; import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken; /** - * Tests for {@link CsrfTokenRequestProcessor}. + * Tests for {@link CsrfTokenRepositoryRequestHandler}. * * @author Steve Riesenberg * @since 5.8 */ @ExtendWith(MockitoExtension.class) -public class CsrfTokenRequestProcessorTests { +public class CsrfTokenRepositoryRequestHandlerTests { @Mock CsrfTokenRepository tokenRepository; @@ -48,34 +48,48 @@ public class CsrfTokenRequestProcessorTests { private CsrfToken token; - private CsrfTokenRequestProcessor processor; + private CsrfTokenRepositoryRequestHandler handler; @BeforeEach public void setup() { this.request = new MockHttpServletRequest(); this.response = new MockHttpServletResponse(); this.token = new DefaultCsrfToken("headerName", "paramName", "csrfTokenValue"); - this.processor = new CsrfTokenRequestProcessor(); - this.processor.setTokenRepository(this.tokenRepository); + this.handler = new CsrfTokenRepositoryRequestHandler(this.tokenRepository); + } + + @Test + public void constructorWhenCsrfTokenRepositoryIsNullThenThrowsIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new CsrfTokenRepositoryRequestHandler(null)) + .withMessage("csrfTokenRepository cannot be null"); + // @formatter:on } @Test public void handleWhenRequestIsNullThenThrowsIllegalArgumentException() { - assertThatIllegalArgumentException().isThrownBy(() -> this.processor.handle(null, this.response)) + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.handler.handle(null, this.response)) .withMessage("request cannot be null"); + // @formatter:on } @Test public void handleWhenResponseIsNullThenThrowsIllegalArgumentException() { - assertThatIllegalArgumentException().isThrownBy(() -> this.processor.handle(this.request, null)) + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.handler.handle(this.request, null)) .withMessage("response cannot be null"); + // @formatter:on } @Test public void handleWhenCsrfRequestAttributeSetThenUsed() { given(this.tokenRepository.generateToken(this.request)).willReturn(this.token); - this.processor.setCsrfRequestAttributeName("_csrf"); - this.processor.handle(this.request, this.response); + this.handler.setCsrfRequestAttributeName("_csrf"); + this.handler.handle(this.request, this.response); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute("_csrf")).isEqualTo(this.token); } @@ -83,40 +97,46 @@ public class CsrfTokenRequestProcessorTests { @Test public void handleWhenValidParametersThenRequestAttributesSet() { given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); - this.processor.handle(this.request, this.response); + this.handler.handle(this.request, this.response); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute("_csrf")).isEqualTo(this.token); } @Test public void resolveCsrfTokenValueWhenRequestIsNullThenThrowsIllegalArgumentException() { - assertThatIllegalArgumentException().isThrownBy(() -> this.processor.resolveCsrfTokenValue(null, this.token)) + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.handler.resolveCsrfTokenValue(null, this.token)) .withMessage("request cannot be null"); + // @formatter:on } @Test public void resolveCsrfTokenValueWhenCsrfTokenIsNullThenThrowsIllegalArgumentException() { - assertThatIllegalArgumentException().isThrownBy(() -> this.processor.resolveCsrfTokenValue(this.request, null)) + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.handler.resolveCsrfTokenValue(this.request, null)) .withMessage("csrfToken cannot be null"); + // @formatter:on } @Test public void resolveCsrfTokenValueWhenTokenNotSetThenReturnsNull() { - String tokenValue = this.processor.resolveCsrfTokenValue(this.request, this.token); + String tokenValue = this.handler.resolveCsrfTokenValue(this.request, this.token); assertThat(tokenValue).isNull(); } @Test public void resolveCsrfTokenValueWhenParameterSetThenReturnsTokenValue() { this.request.setParameter(this.token.getParameterName(), this.token.getToken()); - String tokenValue = this.processor.resolveCsrfTokenValue(this.request, this.token); + String tokenValue = this.handler.resolveCsrfTokenValue(this.request, this.token); assertThat(tokenValue).isEqualTo(this.token.getToken()); } @Test public void resolveCsrfTokenValueWhenHeaderSetThenReturnsTokenValue() { this.request.addHeader(this.token.getHeaderName(), this.token.getToken()); - String tokenValue = this.processor.resolveCsrfTokenValue(this.request, this.token); + String tokenValue = this.handler.resolveCsrfTokenValue(this.request, this.token); assertThat(tokenValue).isEqualTo(this.token.getToken()); } @@ -124,7 +144,7 @@ public class CsrfTokenRequestProcessorTests { public void resolveCsrfTokenValueWhenHeaderAndParameterSetThenHeaderIsPreferred() { this.request.addHeader(this.token.getHeaderName(), "header"); this.request.setParameter(this.token.getParameterName(), "parameter"); - String tokenValue = this.processor.resolveCsrfTokenValue(this.request, this.token); + String tokenValue = this.handler.resolveCsrfTokenValue(this.request, this.token); assertThat(tokenValue).isEqualTo("header"); }