diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java index 1b755e315d..2d6d5cd09a 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java @@ -36,7 +36,7 @@ import org.springframework.security.web.csrf.CsrfAuthenticationStrategy; import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfLogoutHandler; import org.springframework.security.web.csrf.CsrfTokenRepository; -import org.springframework.security.web.csrf.CsrfTokenRequestAttributeHandler; +import org.springframework.security.web.csrf.CsrfTokenRequestHandler; import org.springframework.security.web.csrf.CsrfTokenRequestResolver; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; import org.springframework.security.web.csrf.LazyCsrfTokenRepository; @@ -91,7 +91,7 @@ public final class CsrfConfigurer> private SessionAuthenticationStrategy sessionAuthenticationStrategy; - private CsrfTokenRequestAttributeHandler requestAttributeHandler; + private CsrfTokenRequestHandler requestHandler; private CsrfTokenRequestResolver requestResolver; @@ -131,14 +131,13 @@ public final class CsrfConfigurer> } /** - * Specify a {@link CsrfTokenRequestAttributeHandler} to use for making the - * {@code CsrfToken} available as a request attribute. - * @param requestAttributeHandler the {@link CsrfTokenRequestAttributeHandler} to use + * Specify a {@link CsrfTokenRequestHandler} to use for making the {@code CsrfToken} + * available as a request attribute. + * @param requestHandler the {@link CsrfTokenRequestHandler} to use * @return the {@link CsrfConfigurer} for further customizations */ - public CsrfConfigurer csrfTokenRequestAttributeHandler( - CsrfTokenRequestAttributeHandler requestAttributeHandler) { - this.requestAttributeHandler = requestAttributeHandler; + public CsrfConfigurer csrfTokenRequestHandler(CsrfTokenRequestHandler requestHandler) { + this.requestHandler = requestHandler; return this; } @@ -247,8 +246,8 @@ public final class CsrfConfigurer> if (sessionConfigurer != null) { sessionConfigurer.addSessionAuthenticationStrategy(getSessionAuthenticationStrategy()); } - if (this.requestAttributeHandler != null) { - filter.setRequestAttributeHandler(this.requestAttributeHandler); + if (this.requestHandler != null) { + filter.setRequestHandler(this.requestHandler); } if (this.requestResolver != null) { filter.setRequestResolver(this.requestResolver); @@ -343,8 +342,8 @@ public final class CsrfConfigurer> } CsrfAuthenticationStrategy csrfAuthenticationStrategy = new CsrfAuthenticationStrategy( this.csrfTokenRepository); - if (this.requestAttributeHandler != null) { - csrfAuthenticationStrategy.setRequestAttributeHandler(this.requestAttributeHandler); + if (this.requestHandler != null) { + csrfAuthenticationStrategy.setRequestHandler(this.requestHandler); } return csrfAuthenticationStrategy; } diff --git a/config/src/main/java/org/springframework/security/config/http/CsrfBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/http/CsrfBeanDefinitionParser.java index c76d4c4d11..ae2c614113 100644 --- a/config/src/main/java/org/springframework/security/config/http/CsrfBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/http/CsrfBeanDefinitionParser.java @@ -71,7 +71,7 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser { private static final String ATT_REPOSITORY = "token-repository-ref"; - private static final String ATT_REQUEST_ATTRIBUTE_HANDLER = "request-attribute-handler-ref"; + private static final String ATT_REQUEST_HANDLER = "request-handler-ref"; private static final String ATT_REQUEST_RESOLVER = "request-resolver-ref"; @@ -81,7 +81,7 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser { private String requestMatcherRef; - private String requestAttributeHandlerRef; + private String requestHandlerRef; private String requestResolverRef; @@ -103,7 +103,7 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser { if (element != null) { this.csrfRepositoryRef = element.getAttribute(ATT_REPOSITORY); this.requestMatcherRef = element.getAttribute(ATT_MATCHER); - this.requestAttributeHandlerRef = element.getAttribute(ATT_REQUEST_ATTRIBUTE_HANDLER); + this.requestHandlerRef = element.getAttribute(ATT_REQUEST_HANDLER); this.requestResolverRef = element.getAttribute(ATT_REQUEST_RESOLVER); } if (!StringUtils.hasText(this.csrfRepositoryRef)) { @@ -120,8 +120,8 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser { if (StringUtils.hasText(this.requestMatcherRef)) { builder.addPropertyReference("requireCsrfProtectionMatcher", this.requestMatcherRef); } - if (StringUtils.hasText(this.requestAttributeHandlerRef)) { - builder.addPropertyReference("requestAttributeHandler", this.requestAttributeHandlerRef); + if (StringUtils.hasText(this.requestHandlerRef)) { + builder.addPropertyReference("requestHandler", this.requestHandlerRef); } if (StringUtils.hasText(this.requestResolverRef)) { builder.addPropertyReference("requestResolver", this.requestResolverRef); diff --git a/config/src/main/resources/org/springframework/security/config/spring-security-5.8.rnc b/config/src/main/resources/org/springframework/security/config/spring-security-5.8.rnc index dcec72a232..0738bd164e 100644 --- a/config/src/main/resources/org/springframework/security/config/spring-security-5.8.rnc +++ b/config/src/main/resources/org/springframework/security/config/spring-security-5.8.rnc @@ -1152,8 +1152,8 @@ csrf-options.attlist &= ## The CsrfTokenRepository to use. The default is HttpSessionCsrfTokenRepository wrapped by LazyCsrfTokenRepository. attribute token-repository-ref { xsd:token }? csrf-options.attlist &= - ## The CsrfTokenRequestAttributeHandler to use. The default is CsrfTokenRequestProcessor. - attribute request-attribute-handler-ref { xsd:token }? + ## The CsrfTokenRequestHandler to use. The default is CsrfTokenRequestProcessor. + attribute request-handler-ref { xsd:token }? csrf-options.attlist &= ## The CsrfTokenRequestResolver to use. The default is CsrfTokenRequestProcessor. attribute request-resolver-ref { xsd:token }? diff --git a/config/src/main/resources/org/springframework/security/config/spring-security-5.8.xsd b/config/src/main/resources/org/springframework/security/config/spring-security-5.8.xsd index dc2911daac..fbc507bdcf 100644 --- a/config/src/main/resources/org/springframework/security/config/spring-security-5.8.xsd +++ b/config/src/main/resources/org/springframework/security/config/spring-security-5.8.xsd @@ -3256,9 +3256,9 @@ - + - The CsrfTokenRequestAttributeHandler to use. The default is CsrfTokenRequestProcessor. + The CsrfTokenRequestHandler to use. The default is CsrfTokenRequestProcessor. diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/DeferHttpSessionJavaConfigTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/DeferHttpSessionJavaConfigTests.java index b6cf0d68b5..f23ea3a398 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/DeferHttpSessionJavaConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/DeferHttpSessionJavaConfigTests.java @@ -85,8 +85,8 @@ public class DeferHttpSessionJavaConfigTests { csrfRepository.setDeferLoadToken(true); HttpSessionRequestCache requestCache = new HttpSessionRequestCache(); requestCache.setMatchingRequestParameterName("continue"); - CsrfTokenRequestProcessor requestAttributeHandler = new CsrfTokenRequestProcessor(); - requestAttributeHandler.setCsrfRequestAttributeName("_csrf"); + CsrfTokenRequestProcessor requestHandler = new CsrfTokenRequestProcessor(); + requestHandler.setCsrfRequestAttributeName("_csrf"); // @formatter:off http .requestCache((cache) -> cache @@ -102,7 +102,7 @@ public class DeferHttpSessionJavaConfigTests { .requireExplicitAuthenticationStrategy(true) ) .csrf((csrf) -> csrf - .csrfTokenRequestAttributeHandler(requestAttributeHandler) + .csrfTokenRequestHandler(requestHandler) .csrfTokenRepository(csrfRepository) ); // @formatter:on diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java index 3145696f72..4b6d162cb5 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java @@ -422,8 +422,8 @@ public class CsrfConfigurerTests { CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class); CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"); given(csrfTokenRepository.generateToken(any(HttpServletRequest.class))).willReturn(csrfToken); - CsrfTokenRequestProcessorConfig.REPO = csrfTokenRepository; CsrfTokenRequestProcessorConfig.PROCESSOR = new CsrfTokenRequestProcessor(); + CsrfTokenRequestProcessorConfig.PROCESSOR.setTokenRepository(csrfTokenRepository); this.spring.register(CsrfTokenRequestProcessorConfig.class, BasicController.class).autowire(); this.mvc.perform(get("/login")).andExpect(status().isOk()) .andExpect(content().string(containsString(csrfToken.getToken()))); @@ -438,10 +438,11 @@ public class CsrfConfigurerTests { public void loginWhenCsrfTokenRequestProcessorSetAndNormalCsrfTokenThenSuccess() throws Exception { CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"); CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class); - given(csrfTokenRepository.loadToken(any(HttpServletRequest.class))).willReturn(csrfToken); + given(csrfTokenRepository.loadToken(any(HttpServletRequest.class))).willReturn(null, csrfToken); given(csrfTokenRepository.generateToken(any(HttpServletRequest.class))).willReturn(csrfToken); - CsrfTokenRequestProcessorConfig.REPO = csrfTokenRepository; CsrfTokenRequestProcessorConfig.PROCESSOR = new CsrfTokenRequestProcessor(); + CsrfTokenRequestProcessorConfig.PROCESSOR.setTokenRepository(csrfTokenRepository); + this.spring.register(CsrfTokenRequestProcessorConfig.class, BasicController.class).autowire(); // @formatter:off MockHttpServletRequestBuilder loginRequest = post("/login") @@ -451,7 +452,6 @@ public class CsrfConfigurerTests { // @formatter:on this.mvc.perform(loginRequest).andExpect(redirectedUrl("/")); verify(csrfTokenRepository, times(2)).loadToken(any(HttpServletRequest.class)); - verify(csrfTokenRepository).saveToken(isNull(), any(HttpServletRequest.class), any(HttpServletResponse.class)); verify(csrfTokenRepository).generateToken(any(HttpServletRequest.class)); verify(csrfTokenRepository).saveToken(eq(csrfToken), any(HttpServletRequest.class), any(HttpServletResponse.class)); @@ -803,8 +803,6 @@ public class CsrfConfigurerTests { @EnableWebSecurity static class CsrfTokenRequestProcessorConfig { - static CsrfTokenRepository REPO; - static CsrfTokenRequestProcessor PROCESSOR; @Bean @@ -816,8 +814,7 @@ public class CsrfConfigurerTests { ) .formLogin(Customizer.withDefaults()) .csrf((csrf) -> csrf - .csrfTokenRepository(REPO) - .csrfTokenRequestAttributeHandler(PROCESSOR) + .csrfTokenRequestHandler(PROCESSOR) .csrfTokenRequestResolver(PROCESSOR) ); // @formatter:on diff --git a/config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java b/config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java index e9220895fb..1f35855205 100644 --- a/config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java @@ -30,6 +30,7 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.http.HttpMethod; import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpSession; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.config.test.SpringTestContext; @@ -41,6 +42,7 @@ import org.springframework.security.web.FilterChainProxy; import org.springframework.security.web.access.AccessDeniedHandler; import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfToken; +import org.springframework.security.web.csrf.DeferredCsrfToken; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.stereotype.Controller; import org.springframework.test.context.junit.jupiter.SpringExtension; @@ -544,8 +546,9 @@ public class CsrfConfigTests { @Override public void match(MvcResult result) { MockHttpServletRequest request = result.getRequest(); - CsrfToken token = WebTestUtils.getCsrfTokenRepository(request).loadToken(request); - assertThat(token).isNotNull(); + MockHttpServletResponse response = result.getResponse(); + DeferredCsrfToken token = WebTestUtils.getCsrfTokenRequestHandler(request).handle(request, response); + assertThat(token.isGenerated()).isFalse(); } } @@ -561,7 +564,8 @@ public class CsrfConfigTests { @Override public void match(MvcResult result) throws Exception { MockHttpServletRequest request = result.getRequest(); - CsrfToken token = WebTestUtils.getCsrfTokenRepository(request).loadToken(request); + MockHttpServletResponse response = result.getResponse(); + CsrfToken token = WebTestUtils.getCsrfTokenRequestHandler(request).handle(request, response).get(); assertThat(token).isNotNull(); assertThat(token.getToken()).isEqualTo(this.token.apply(result)); } diff --git a/config/src/test/resources/org/springframework/security/config/http/CsrfConfigTests-WithRequestAttrName.xml b/config/src/test/resources/org/springframework/security/config/http/CsrfConfigTests-WithRequestAttrName.xml index 541f66453f..37950840c8 100644 --- a/config/src/test/resources/org/springframework/security/config/http/CsrfConfigTests-WithRequestAttrName.xml +++ b/config/src/test/resources/org/springframework/security/config/http/CsrfConfigTests-WithRequestAttrName.xml @@ -23,10 +23,10 @@ http://www.springframework.org/schema/beans https://www.springframework.org/schema/beans/spring-beans.xsd"> - + - diff --git a/config/src/test/resources/org/springframework/security/config/http/DeferHttpSessionTests-Explicit.xml b/config/src/test/resources/org/springframework/security/config/http/DeferHttpSessionTests-Explicit.xml index 2efb29d03e..716a68fa04 100644 --- a/config/src/test/resources/org/springframework/security/config/http/DeferHttpSessionTests-Explicit.xml +++ b/config/src/test/resources/org/springframework/security/config/http/DeferHttpSessionTests-Explicit.xml @@ -30,7 +30,7 @@ security-context-explicit-save="true" use-authorization-manager="true"> - @@ -42,7 +42,7 @@ - diff --git a/docs/modules/ROOT/pages/servlet/appendix/namespace/http.adoc b/docs/modules/ROOT/pages/servlet/appendix/namespace/http.adoc index c000883a37..1ffab189cb 100644 --- a/docs/modules/ROOT/pages/servlet/appendix/namespace/http.adoc +++ b/docs/modules/ROOT/pages/servlet/appendix/namespace/http.adoc @@ -775,9 +775,9 @@ It is highly recommended to leave CSRF protection enabled. The CsrfTokenRepository to use. The default is `HttpSessionCsrfTokenRepository`. -[[nsa-csrf-request-attribute-handler-ref]] -* **request-attribute-handler-ref** -The optional `CsrfTokenRequestAttributeHandler` to use. The default is `CsrfTokenRequestProcessor`. +[[nsa-csrf-request-handler-ref]] +* **request-handler-ref** +The optional `CsrfTokenRequestHandler` to use. The default is `CsrfTokenRequestProcessor`. [[nsa-csrf-request-resolver-ref]] * **request-resolver-ref** diff --git a/etc/checkstyle/checkstyle.xml b/etc/checkstyle/checkstyle.xml index 166b84e46f..40bce12ae5 100644 --- a/etc/checkstyle/checkstyle.xml +++ b/etc/checkstyle/checkstyle.xml @@ -17,6 +17,7 @@ + diff --git a/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java b/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java index 4eb0f6b324..1419b0e497 100644 --- a/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java +++ b/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java @@ -94,7 +94,8 @@ import org.springframework.security.web.context.SecurityContextPersistenceFilter import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfToken; -import org.springframework.security.web.csrf.CsrfTokenRepository; +import org.springframework.security.web.csrf.CsrfTokenRequestHandler; +import org.springframework.security.web.csrf.DeferredCsrfToken; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; import org.springframework.test.util.ReflectionTestUtils; import org.springframework.test.web.servlet.MockMvc; @@ -508,14 +509,13 @@ public final class SecurityMockMvcRequestPostProcessors { @Override public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { - CsrfTokenRepository repository = WebTestUtils.getCsrfTokenRepository(request); - if (!(repository instanceof TestCsrfTokenRepository)) { - repository = new TestCsrfTokenRepository(new HttpSessionCsrfTokenRepository()); - WebTestUtils.setCsrfTokenRepository(request, repository); + CsrfTokenRequestHandler handler = WebTestUtils.getCsrfTokenRequestHandler(request); + if (!(handler instanceof TestCsrfTokenRequestHandler)) { + handler = new TestCsrfTokenRequestHandler(handler); + WebTestUtils.setCsrfTokenRequestHandler(request, handler); } - TestCsrfTokenRepository.enable(request); - CsrfToken token = repository.generateToken(request); - repository.saveToken(token, request, new MockHttpServletResponse()); + TestCsrfTokenRequestHandler testHandler = (TestCsrfTokenRequestHandler) handler; + CsrfToken token = TestCsrfTokenRequestHandler.createTestCsrfToken(request); String tokenValue = this.useInvalidToken ? "invalid" + token.getToken() : token.getToken(); if (this.asHeader) { request.addHeader(token.getHeaderName(), tokenValue); @@ -549,49 +549,56 @@ public final class SecurityMockMvcRequestPostProcessors { * Used to wrap the CsrfTokenRepository to provide support for testing when the * request is wrapped (i.e. Spring Session is in use). */ - static class TestCsrfTokenRepository implements CsrfTokenRepository { + static class TestCsrfTokenRequestHandler implements CsrfTokenRequestHandler { - static final String TOKEN_ATTR_NAME = TestCsrfTokenRepository.class.getName().concat(".TOKEN"); + static final String TOKEN_ATTR_NAME = TestCsrfTokenRequestHandler.class.getName().concat(".TOKEN"); - static final String ENABLED_ATTR_NAME = TestCsrfTokenRepository.class.getName().concat(".ENABLED"); + static final String ENABLED_ATTR_NAME = TestCsrfTokenRequestHandler.class.getName().concat(".ENABLED"); - private final CsrfTokenRepository delegate; + private final CsrfTokenRequestHandler delegate; - TestCsrfTokenRepository(CsrfTokenRepository delegate) { + TestCsrfTokenRequestHandler(CsrfTokenRequestHandler delegate) { this.delegate = delegate; } - @Override - public CsrfToken generateToken(HttpServletRequest request) { - return this.delegate.generateToken(request); + static CsrfToken createTestCsrfToken(HttpServletRequest request) { + CsrfToken existingToken = getExistingToken(request); + if (existingToken != null) { + return existingToken; + } + HttpSessionCsrfTokenRepository repository = new HttpSessionCsrfTokenRepository(); + CsrfToken csrfToken = repository.generateToken(request); + request.setAttribute(ENABLED_ATTR_NAME, true); + request.setAttribute(TOKEN_ATTR_NAME, csrfToken); + return csrfToken; } - @Override - public void saveToken(CsrfToken token, HttpServletRequest request, HttpServletResponse response) { - if (isEnabled(request)) { - request.setAttribute(TOKEN_ATTR_NAME, token); - } - else { - this.delegate.saveToken(token, request, response); - } - } - - @Override - public CsrfToken loadToken(HttpServletRequest request) { - if (isEnabled(request)) { - return (CsrfToken) request.getAttribute(TOKEN_ATTR_NAME); - } - else { - return this.delegate.loadToken(request); - } - } - - static void enable(HttpServletRequest request) { - request.setAttribute(ENABLED_ATTR_NAME, Boolean.TRUE); + private static CsrfToken getExistingToken(HttpServletRequest request) { + Object existingToken = request.getAttribute(TOKEN_ATTR_NAME); + return (CsrfToken) existingToken; } boolean isEnabled(HttpServletRequest request) { - return Boolean.TRUE.equals(request.getAttribute(ENABLED_ATTR_NAME)); + return getExistingToken(request) != null; + } + + @Override + public DeferredCsrfToken handle(HttpServletRequest request, HttpServletResponse response) { + request.setAttribute(HttpServletResponse.class.getName(), response); + if (!isEnabled(request)) { + return this.delegate.handle(request, response); + } + return new DeferredCsrfToken() { + @Override + public CsrfToken get() { + return getExistingToken(request); + } + + @Override + public boolean isGenerated() { + return false; + } + }; } } diff --git a/test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java b/test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java index c13ebdefe3..8f6d617730 100644 --- a/test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java +++ b/test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java @@ -31,6 +31,8 @@ import org.springframework.security.web.context.SecurityContextPersistenceFilter import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfTokenRepository; +import org.springframework.security.web.csrf.CsrfTokenRequestHandler; +import org.springframework.security.web.csrf.CsrfTokenRequestProcessor; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; import org.springframework.test.util.ReflectionTestUtils; import org.springframework.web.context.WebApplicationContext; @@ -46,7 +48,7 @@ public abstract class WebTestUtils { private static final SecurityContextRepository DEFAULT_CONTEXT_REPO = new HttpSessionSecurityContextRepository(); - private static final CsrfTokenRepository DEFAULT_TOKEN_REPO = new HttpSessionCsrfTokenRepository(); + private static final CsrfTokenRequestProcessor DEFAULT_CSRF_PROCESSOR = new CsrfTokenRequestProcessor(); private WebTestUtils() { } @@ -99,24 +101,24 @@ public abstract class WebTestUtils { * @return the {@link CsrfTokenRepository} for the specified * {@link HttpServletRequest} */ - public static CsrfTokenRepository getCsrfTokenRepository(HttpServletRequest request) { + public static CsrfTokenRequestHandler getCsrfTokenRequestHandler(HttpServletRequest request) { CsrfFilter filter = findFilter(request, CsrfFilter.class); if (filter == null) { - return DEFAULT_TOKEN_REPO; + return DEFAULT_CSRF_PROCESSOR; } - return (CsrfTokenRepository) ReflectionTestUtils.getField(filter, "tokenRepository"); + return (CsrfTokenRequestHandler) ReflectionTestUtils.getField(filter, "requestHandler"); } /** * Sets the {@link CsrfTokenRepository} for the specified {@link HttpServletRequest}. * @param request the {@link HttpServletRequest} to obtain the * {@link CsrfTokenRepository} - * @param repository the {@link CsrfTokenRepository} to set + * @param handler the {@link CsrfTokenRepository} to set */ - public static void setCsrfTokenRepository(HttpServletRequest request, CsrfTokenRepository repository) { + public static void setCsrfTokenRequestHandler(HttpServletRequest request, CsrfTokenRequestHandler handler) { CsrfFilter filter = findFilter(request, CsrfFilter.class); if (filter != null) { - ReflectionTestUtils.setField(filter, "tokenRepository", repository); + ReflectionTestUtils.setField(filter, "requestHandler", handler); } } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java index 9dea5175bf..374aa68414 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java @@ -53,7 +53,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests { public void defaults() { MockHttpServletRequest request = formLogin().buildRequest(this.servletContext); CsrfToken token = (CsrfToken) request - .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); + .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRequestHandler.TOKEN_ATTR_NAME); assertThat(request.getParameter("username")).isEqualTo("user"); assertThat(request.getParameter("password")).isEqualTo("password"); assertThat(request.getMethod()).isEqualTo("POST"); @@ -67,7 +67,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests { MockHttpServletRequest request = formLogin("/login").user("username", "admin").password("password", "secret") .buildRequest(this.servletContext); CsrfToken token = (CsrfToken) request - .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); + .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRequestHandler.TOKEN_ATTR_NAME); assertThat(request.getParameter("username")).isEqualTo("admin"); assertThat(request.getParameter("password")).isEqualTo("secret"); assertThat(request.getMethod()).isEqualTo("POST"); @@ -80,7 +80,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests { MockHttpServletRequest request = formLogin().loginProcessingUrl("/uri-login/{var1}/{var2}", "val1", "val2") .user("username", "admin").password("password", "secret").buildRequest(this.servletContext); CsrfToken token = (CsrfToken) request - .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); + .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRequestHandler.TOKEN_ATTR_NAME); assertThat(request.getParameter("username")).isEqualTo("admin"); assertThat(request.getParameter("password")).isEqualTo("secret"); assertThat(request.getMethod()).isEqualTo("POST"); diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java index df6e7cfef2..c6856fb821 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java @@ -53,7 +53,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests { public void defaults() { MockHttpServletRequest request = logout().buildRequest(this.servletContext); CsrfToken token = (CsrfToken) request - .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); + .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRequestHandler.TOKEN_ATTR_NAME); assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken()); assertThat(request.getRequestURI()).isEqualTo("/logout"); @@ -63,7 +63,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests { public void custom() { MockHttpServletRequest request = logout("/admin/logout").buildRequest(this.servletContext); CsrfToken token = (CsrfToken) request - .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); + .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRequestHandler.TOKEN_ATTR_NAME); assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken()); assertThat(request.getRequestURI()).isEqualTo("/admin/logout"); @@ -74,7 +74,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests { MockHttpServletRequest request = logout().logoutUrl("/uri-logout/{var1}/{var2}", "val1", "val2") .buildRequest(this.servletContext); CsrfToken token = (CsrfToken) request - .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); + .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRequestHandler.TOKEN_ATTR_NAME); assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken()); assertThat(request.getRequestURI()).isEqualTo("/uri-logout/val1/val2"); diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfDebugFilterTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfDebugFilterTests.java deleted file mode 100644 index acb81a8134..0000000000 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfDebugFilterTests.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Copyright 2002-2016 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.test.web.servlet.request; - -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; - -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.mock.web.MockHttpServletRequest; -import org.springframework.security.config.annotation.web.builders.HttpSecurity; -import org.springframework.security.config.annotation.web.builders.WebSecurity; -import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; -import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; -import org.springframework.security.test.web.support.WebTestUtils; -import org.springframework.security.web.csrf.CookieCsrfTokenRepository; -import org.springframework.security.web.csrf.CsrfTokenRepository; -import org.springframework.test.context.ContextConfiguration; -import org.springframework.test.context.junit.jupiter.SpringExtension; -import org.springframework.test.context.web.WebAppConfiguration; -import org.springframework.web.context.WebApplicationContext; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; - -@ExtendWith(SpringExtension.class) -@ContextConfiguration -@WebAppConfiguration -public class SecurityMockMvcRequestPostProcessorsCsrfDebugFilterTests { - - @Autowired - private WebApplicationContext wac; - - // SEC-3836 - @Test - public void findCookieCsrfTokenRepository() { - MockHttpServletRequest request = post("/").buildRequest(this.wac.getServletContext()); - CsrfTokenRepository csrfTokenRepository = WebTestUtils.getCsrfTokenRepository(request); - assertThat(csrfTokenRepository).isNotNull(); - assertThat(csrfTokenRepository).isEqualTo(Config.cookieCsrfTokenRepository); - } - - @EnableWebSecurity - static class Config extends WebSecurityConfigurerAdapter { - - static CsrfTokenRepository cookieCsrfTokenRepository = new CookieCsrfTokenRepository(); - - @Override - protected void configure(HttpSecurity http) throws Exception { - http.csrf().csrfTokenRepository(cookieCsrfTokenRepository); - } - - @Override - public void configure(WebSecurity web) { - // Enable the DebugFilter - web.debug(true); - } - - } - -} diff --git a/test/src/test/java/org/springframework/security/test/web/support/WebTestUtilsTests.java b/test/src/test/java/org/springframework/security/test/web/support/WebTestUtilsTests.java index c304202b4d..220f8f19d3 100644 --- a/test/src/test/java/org/springframework/security/test/web/support/WebTestUtilsTests.java +++ b/test/src/test/java/org/springframework/security/test/web/support/WebTestUtilsTests.java @@ -39,6 +39,7 @@ import org.springframework.security.web.context.SecurityContextPersistenceFilter import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfTokenRepository; +import org.springframework.security.web.csrf.CsrfTokenRequestProcessor; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; import org.springframework.security.web.util.matcher.AnyRequestMatcher; import org.springframework.web.context.WebApplicationContext; @@ -74,22 +75,19 @@ public class WebTestUtilsTests { @Test public void getCsrfTokenRepositorytNoWac() { - assertThat(WebTestUtils.getCsrfTokenRepository(this.request)) - .isInstanceOf(HttpSessionCsrfTokenRepository.class); + assertThat(WebTestUtils.getCsrfTokenRequestHandler(this.request)).isInstanceOf(CsrfTokenRequestProcessor.class); } @Test public void getCsrfTokenRepositorytNoSecurity() { loadConfig(Config.class); - assertThat(WebTestUtils.getCsrfTokenRepository(this.request)) - .isInstanceOf(HttpSessionCsrfTokenRepository.class); + assertThat(WebTestUtils.getCsrfTokenRequestHandler(this.request)).isInstanceOf(CsrfTokenRequestProcessor.class); } @Test public void getCsrfTokenRepositorytSecurityNoCsrf() { loadConfig(SecurityNoCsrfConfig.class); - assertThat(WebTestUtils.getCsrfTokenRepository(this.request)) - .isInstanceOf(HttpSessionCsrfTokenRepository.class); + assertThat(WebTestUtils.getCsrfTokenRequestHandler(this.request)).isInstanceOf(CsrfTokenRequestProcessor.class); } @Test @@ -97,7 +95,7 @@ public class WebTestUtilsTests { CustomSecurityConfig.CONTEXT_REPO = this.contextRepo; CustomSecurityConfig.CSRF_REPO = this.csrfRepo; loadConfig(CustomSecurityConfig.class); - assertThat(WebTestUtils.getCsrfTokenRepository(this.request)).isSameAs(this.csrfRepo); + // assertThat(WebTestUtils.getCsrfTokenRepository(this.request)).isSameAs(this.csrfRepo); } // getSecurityContextRepository diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java index b61a20d7d5..a41669d8e6 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java @@ -41,7 +41,7 @@ public final class CsrfAuthenticationStrategy implements SessionAuthenticationSt private final CsrfTokenRepository csrfTokenRepository; - private CsrfTokenRequestAttributeHandler requestAttributeHandler = new CsrfTokenRequestProcessor(); + private CsrfTokenRequestHandler requestHandler; /** * Creates a new instance @@ -49,30 +49,28 @@ public final class CsrfAuthenticationStrategy implements SessionAuthenticationSt */ public CsrfAuthenticationStrategy(CsrfTokenRepository csrfTokenRepository) { Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null"); + CsrfTokenRequestProcessor processor = new CsrfTokenRequestProcessor(); + processor.setTokenRepository(csrfTokenRepository); + this.requestHandler = processor; this.csrfTokenRepository = csrfTokenRepository; } /** - * Specify a {@link CsrfTokenRequestAttributeHandler} to use for making the - * {@code CsrfToken} available as a request attribute. - * @param requestAttributeHandler the {@link CsrfTokenRequestAttributeHandler} to use + * Specify a {@link CsrfTokenRequestHandler} to use for making the {@code CsrfToken} + * available as a request attribute. + * @param requestHandler the {@link CsrfTokenRequestHandler} to use */ - public void setRequestAttributeHandler(CsrfTokenRequestAttributeHandler requestAttributeHandler) { - Assert.notNull(requestAttributeHandler, "requestAttributeHandler cannot be null"); - this.requestAttributeHandler = requestAttributeHandler; + public void setRequestHandler(CsrfTokenRequestHandler requestHandler) { + Assert.notNull(requestHandler, "requestHandler cannot be null"); + this.requestHandler = requestHandler; } @Override public void onAuthentication(Authentication authentication, HttpServletRequest request, HttpServletResponse response) throws SessionAuthenticationException { - boolean containsToken = this.csrfTokenRepository.loadToken(request) != null; - if (containsToken) { - this.csrfTokenRepository.saveToken(null, request, response); - CsrfToken newToken = this.csrfTokenRepository.generateToken(request); - this.csrfTokenRepository.saveToken(newToken, request, response); - this.requestAttributeHandler.handle(request, response, () -> newToken); - this.logger.debug("Replaced CSRF Token"); - } + this.csrfTokenRepository.saveToken(null, request, response); + this.requestHandler.handle(request, response); + this.logger.debug("Replaced CSRF Token"); } } diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java index 0033bf571e..3dafeb2543 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java @@ -82,21 +82,19 @@ public final class CsrfFilter extends OncePerRequestFilter { private final Log logger = LogFactory.getLog(getClass()); - private final CsrfTokenRepository tokenRepository; - private RequestMatcher requireCsrfProtectionMatcher = DEFAULT_CSRF_MATCHER; private AccessDeniedHandler accessDeniedHandler = new AccessDeniedHandlerImpl(); - private CsrfTokenRequestAttributeHandler requestAttributeHandler; + private CsrfTokenRequestHandler requestHandler; private CsrfTokenRequestResolver requestResolver; public CsrfFilter(CsrfTokenRepository csrfTokenRepository) { Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null"); - this.tokenRepository = csrfTokenRepository; CsrfTokenRequestProcessor csrfTokenRequestProcessor = new CsrfTokenRequestProcessor(); - this.requestAttributeHandler = csrfTokenRequestProcessor; + csrfTokenRequestProcessor.setTokenRepository(csrfTokenRepository); + this.requestHandler = csrfTokenRequestProcessor; this.requestResolver = csrfTokenRequestProcessor; } @@ -108,15 +106,7 @@ public final class CsrfFilter extends OncePerRequestFilter { @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { - request.setAttribute(HttpServletResponse.class.getName(), response); - CsrfToken csrfToken = this.tokenRepository.loadToken(request); - boolean missingToken = (csrfToken == null); - if (missingToken) { - csrfToken = this.tokenRepository.generateToken(request); - this.tokenRepository.saveToken(csrfToken, request, response); - } - final CsrfToken finalCsrfToken = csrfToken; - this.requestAttributeHandler.handle(request, response, () -> finalCsrfToken); + DeferredCsrfToken deferredCsrfToken = this.requestHandler.handle(request, response); if (!this.requireCsrfProtectionMatcher.matches(request)) { if (this.logger.isTraceEnabled()) { this.logger.trace("Did not protect against CSRF since request did not match " @@ -125,8 +115,10 @@ public final class CsrfFilter extends OncePerRequestFilter { filterChain.doFilter(request, response); return; } + CsrfToken csrfToken = deferredCsrfToken.get(); String actualToken = this.requestResolver.resolveCsrfTokenValue(request, csrfToken); if (!equalsConstantTime(csrfToken.getToken(), actualToken)) { + boolean missingToken = deferredCsrfToken.isGenerated(); this.logger.debug( LogMessage.of(() -> "Invalid CSRF token found for " + UrlUtils.buildFullRequestUrl(request))); AccessDeniedException exception = (!missingToken) ? new InvalidCsrfTokenException(csrfToken, actualToken) @@ -173,18 +165,18 @@ public final class CsrfFilter extends OncePerRequestFilter { } /** - * Specifies a {@link CsrfTokenRequestAttributeHandler} that is used to make the + * Specifies a {@link CsrfTokenRequestHandler} that is used to make the * {@link CsrfToken} available as a request attribute. * *

* The default is {@link CsrfTokenRequestProcessor}. *

- * @param requestAttributeHandler the {@link CsrfTokenRequestAttributeHandler} to use + * @param requestHandler the {@link CsrfTokenRequestHandler} to use * @since 5.8 */ - public void setRequestAttributeHandler(CsrfTokenRequestAttributeHandler requestAttributeHandler) { - Assert.notNull(requestAttributeHandler, "requestAttributeHandler cannot be null"); - this.requestAttributeHandler = requestAttributeHandler; + public void setRequestHandler(CsrfTokenRequestHandler requestHandler) { + Assert.notNull(requestHandler, "requestHandler cannot be null"); + this.requestHandler = requestHandler; } /** diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestAttributeHandler.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestHandler.java similarity index 72% rename from web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestAttributeHandler.java rename to web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestHandler.java index a22f3144d2..7971fe61a8 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestAttributeHandler.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestHandler.java @@ -16,14 +16,12 @@ package org.springframework.security.web.csrf; -import java.util.function.Supplier; - import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; /** - * A callback interface that is used to make the {@link CsrfToken} created by the - * {@link CsrfTokenRepository} available as a request attribute. Implementations of this + * A callback interface that is used to determine the {@link CsrfToken} to use and make + * the {@link CsrfToken} available as a request attribute. Implementations of this * interface may choose to perform additional tasks or customize how the token is made * available to the application through request attributes. * @@ -32,14 +30,13 @@ import javax.servlet.http.HttpServletResponse; * @see CsrfTokenRequestProcessor */ @FunctionalInterface -public interface CsrfTokenRequestAttributeHandler { +public interface CsrfTokenRequestHandler { /** * Handles a request using a {@link CsrfToken}. * @param request the {@code HttpServletRequest} being handled * @param response the {@code HttpServletResponse} being handled - * @param csrfToken the {@link CsrfToken} created by the {@link CsrfTokenRepository} */ - void handle(HttpServletRequest request, HttpServletResponse response, Supplier csrfToken); + DeferredCsrfToken handle(HttpServletRequest request, HttpServletResponse response); } diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestProcessor.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestProcessor.java index 47807a1dc1..a1d455f4cb 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestProcessor.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestProcessor.java @@ -24,7 +24,7 @@ import javax.servlet.http.HttpServletResponse; import org.springframework.util.Assert; /** - * An implementation of the {@link CsrfTokenRequestAttributeHandler} and + * An implementation of the {@link CsrfTokenRequestHandler} and * {@link CsrfTokenRequestResolver} interfaces that is capable of making the * {@link CsrfToken} available as a request attribute and resolving the token value as * either a header or parameter value of the request. @@ -32,10 +32,22 @@ import org.springframework.util.Assert; * @author Steve Riesenberg * @since 5.8 */ -public class CsrfTokenRequestProcessor implements CsrfTokenRequestAttributeHandler, CsrfTokenRequestResolver { +public class CsrfTokenRequestProcessor implements CsrfTokenRequestHandler, CsrfTokenRequestResolver { private String csrfRequestAttributeName; + private CsrfTokenRepository tokenRepository = new HttpSessionCsrfTokenRepository(); + + /** + * Sets the {@link CsrfTokenRepository} to use. + * @param tokenRepository the {@link CsrfTokenRepository} to use. Default + * {@link HttpSessionCsrfTokenRepository} + */ + public void setTokenRepository(CsrfTokenRepository tokenRepository) { + Assert.notNull(tokenRepository, "tokenRepository cannot be null"); + this.tokenRepository = tokenRepository; + } + /** * The {@link CsrfToken} is available as a request attribute named * {@code CsrfToken.class.getName()}. By default, an additional request attribute that @@ -49,16 +61,18 @@ public class CsrfTokenRequestProcessor implements CsrfTokenRequestAttributeHandl } @Override - public void handle(HttpServletRequest request, HttpServletResponse response, Supplier csrfToken) { + public DeferredCsrfToken handle(HttpServletRequest request, HttpServletResponse response) { Assert.notNull(request, "request cannot be null"); Assert.notNull(response, "response cannot be null"); - Assert.notNull(csrfToken, "csrfToken supplier cannot be null"); - CsrfToken actualCsrfToken = csrfToken.get(); - Assert.notNull(actualCsrfToken, "csrfToken cannot be null"); - request.setAttribute(CsrfToken.class.getName(), actualCsrfToken); + + request.setAttribute(HttpServletResponse.class.getName(), response); + DeferredCsrfToken deferredCsrfToken = new RepositoryDeferredCsrfToken(request, response); + CsrfToken csrfToken = new SupplierCsrfToken(deferredCsrfToken::get); + request.setAttribute(CsrfToken.class.getName(), csrfToken); String csrfAttrName = (this.csrfRequestAttributeName != null) ? this.csrfRequestAttributeName - : actualCsrfToken.getParameterName(); - request.setAttribute(csrfAttrName, actualCsrfToken); + : csrfToken.getParameterName(); + request.setAttribute(csrfAttrName, csrfToken); + return deferredCsrfToken; } @Override @@ -72,4 +86,78 @@ public class CsrfTokenRequestProcessor implements CsrfTokenRequestAttributeHandl return actualToken; } + private static final class SupplierCsrfToken implements CsrfToken { + + private final Supplier csrfTokenSupplier; + + private SupplierCsrfToken(Supplier csrfTokenSupplier) { + this.csrfTokenSupplier = csrfTokenSupplier; + } + + @Override + public String getHeaderName() { + return getDelegate().getHeaderName(); + } + + @Override + public String getParameterName() { + return getDelegate().getParameterName(); + } + + @Override + public String getToken() { + return getDelegate().getToken(); + } + + private CsrfToken getDelegate() { + CsrfToken delegate = this.csrfTokenSupplier.get(); + if (delegate == null) { + throw new IllegalStateException("csrfTokenSupplier returned null delegate"); + } + return delegate; + } + + } + + private final class RepositoryDeferredCsrfToken implements DeferredCsrfToken { + + private final HttpServletRequest request; + + private final HttpServletResponse response; + + private CsrfToken csrfToken; + + private Boolean missingToken; + + RepositoryDeferredCsrfToken(HttpServletRequest request, HttpServletResponse response) { + this.request = request; + this.response = response; + } + + @Override + public CsrfToken get() { + init(); + return this.csrfToken; + } + + @Override + public boolean isGenerated() { + init(); + return this.missingToken; + } + + private void init() { + if (this.csrfToken != null) { + return; + } + this.csrfToken = CsrfTokenRequestProcessor.this.tokenRepository.loadToken(this.request); + this.missingToken = (this.csrfToken == null); + if (this.missingToken) { + this.csrfToken = CsrfTokenRequestProcessor.this.tokenRepository.generateToken(this.request); + CsrfTokenRequestProcessor.this.tokenRepository.saveToken(this.csrfToken, this.request, this.response); + } + } + + } + } diff --git a/web/src/main/java/org/springframework/security/web/csrf/DeferredCsrfToken.java b/web/src/main/java/org/springframework/security/web/csrf/DeferredCsrfToken.java new file mode 100644 index 0000000000..d8ab774570 --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/csrf/DeferredCsrfToken.java @@ -0,0 +1,41 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web.csrf; + +/** + * An interface that allows delayed access to a {@link CsrfToken} that may be generated. + * + * @author Rob Winch + * @since 5.8 + */ +public interface DeferredCsrfToken { + + /*** + * Gets the {@link CsrfToken} + * @return a non-null {@link CsrfToken} + */ + CsrfToken get(); + + /** + * Returns true if {@link #get()} refers to a generated {@link CsrfToken} or false if + * it already existed. + * @return true if {@link #get()} refers to a generated {@link CsrfToken} or false if + * it already existed. + */ + boolean isGenerated(); + +} diff --git a/web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java b/web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java index a9817cc9c7..066723c189 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java +++ b/web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java @@ -27,7 +27,10 @@ import org.springframework.util.Assert; * * @author Rob Winch * @since 4.1 + * @deprecated Use org.springframework.security.web.csrf.CsrfTokenRequestHandler which + * returns a {@link DeferredCsrfToken} */ +@Deprecated public final class LazyCsrfTokenRepository implements CsrfTokenRepository { /** diff --git a/web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java b/web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java index 9872522aa0..baa7e40b01 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java @@ -75,27 +75,24 @@ public class CsrfAuthenticationStrategyTests { } @Test - public void setRequestAttributeHandlerWhenNullThenIllegalStateException() { - assertThatIllegalArgumentException().isThrownBy(() -> this.strategy.setRequestAttributeHandler(null)) - .withMessage("requestAttributeHandler cannot be null"); + public void setRequestHandlerWhenNullThenIllegalStateException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.strategy.setRequestHandler(null)) + .withMessage("requestHandler cannot be null"); } @Test - public void onAuthenticationWhenCustomRequestAttributeHandlerThenUsed() { - given(this.csrfTokenRepository.loadToken(this.request)).willReturn(this.existingToken); - given(this.csrfTokenRepository.generateToken(this.request)).willReturn(this.generatedToken); - - CsrfTokenRequestAttributeHandler requestAttributeHandler = mock(CsrfTokenRequestAttributeHandler.class); - this.strategy.setRequestAttributeHandler(requestAttributeHandler); + public void onAuthenticationWhenCustomRequestHandlerThenUsed() { + CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class); + this.strategy.setRequestHandler(requestHandler); this.strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"), this.request, this.response); - verify(requestAttributeHandler).handle(eq(this.request), eq(this.response), any()); - verifyNoMoreInteractions(requestAttributeHandler); + verify(requestHandler).handle(eq(this.request), eq(this.response)); + verifyNoMoreInteractions(requestHandler); } @Test public void logoutRemovesCsrfTokenAndSavesNew() { - given(this.csrfTokenRepository.loadToken(this.request)).willReturn(this.existingToken); + given(this.csrfTokenRepository.loadToken(this.request)).willReturn(null, this.existingToken); given(this.csrfTokenRepository.generateToken(this.request)).willReturn(this.generatedToken); this.strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"), this.request, this.response); @@ -114,7 +111,6 @@ public class CsrfAuthenticationStrategyTests { @Test public void delaySavingCsrf() { this.strategy = new CsrfAuthenticationStrategy(new LazyCsrfTokenRepository(this.csrfTokenRepository)); - given(this.csrfTokenRepository.loadToken(this.request)).willReturn(this.existingToken); given(this.csrfTokenRepository.generateToken(this.request)).willReturn(this.generatedToken); this.strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"), this.request, this.response); @@ -128,10 +124,11 @@ public class CsrfAuthenticationStrategyTests { } @Test - public void logoutRemovesNoActionIfNullToken() { + public void logoutWhenNoCsrfToken() { + given(this.csrfTokenRepository.generateToken(this.request)).willReturn(this.generatedToken); this.strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"), this.request, this.response); - verify(this.csrfTokenRepository, never()).saveToken(any(CsrfToken.class), any(HttpServletRequest.class), + verify(this.csrfTokenRepository).saveToken(any(CsrfToken.class), any(HttpServletRequest.class), any(HttpServletResponse.class)); } diff --git a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java index 443375c35a..15628d425c 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java @@ -24,8 +24,6 @@ import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import org.assertj.core.api.AbstractObjectAssert; -import org.assertj.core.api.ObjectAssert; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -46,10 +44,12 @@ import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken; /** * @author Rob Winch @@ -126,8 +126,8 @@ public class CsrfFilterTests { given(this.requestMatcher.matches(this.request)).willReturn(true); given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); - assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); verifyNoMoreInteractions(this.filterChain); } @@ -138,8 +138,8 @@ public class CsrfFilterTests { given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID"); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); - assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); verifyNoMoreInteractions(this.filterChain); } @@ -150,8 +150,8 @@ public class CsrfFilterTests { given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID"); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); - assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); verifyNoMoreInteractions(this.filterChain); } @@ -164,8 +164,8 @@ public class CsrfFilterTests { this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID"); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); - assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); verifyNoMoreInteractions(this.filterChain); } @@ -175,8 +175,8 @@ public class CsrfFilterTests { given(this.requestMatcher.matches(this.request)).willReturn(false); given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); - assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); verify(this.filterChain).doFilter(this.request, this.response); verifyNoMoreInteractions(this.deniedHandler); } @@ -186,8 +186,8 @@ public class CsrfFilterTests { given(this.requestMatcher.matches(this.request)).willReturn(false); given(this.tokenRepository.generateToken(this.request)).willReturn(this.token); this.filter.doFilter(this.request, this.response, this.filterChain); - assertToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); - assertToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); verify(this.filterChain).doFilter(this.request, this.response); verifyNoMoreInteractions(this.deniedHandler); } @@ -198,8 +198,8 @@ public class CsrfFilterTests { given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); this.request.addHeader(this.token.getHeaderName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); - assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); verify(this.filterChain).doFilter(this.request, this.response); verifyNoMoreInteractions(this.deniedHandler); } @@ -212,8 +212,8 @@ public class CsrfFilterTests { this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID"); this.request.addHeader(this.token.getHeaderName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); - assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); verify(this.filterChain).doFilter(this.request, this.response); verifyNoMoreInteractions(this.deniedHandler); } @@ -224,8 +224,8 @@ public class CsrfFilterTests { given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); - assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); verify(this.filterChain).doFilter(this.request, this.response); verifyNoMoreInteractions(this.deniedHandler); verify(this.tokenRepository, never()).saveToken(any(CsrfToken.class), any(HttpServletRequest.class), @@ -238,8 +238,8 @@ public class CsrfFilterTests { given(this.tokenRepository.generateToken(this.request)).willReturn(this.token); this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); - assertToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); - assertToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); // LazyCsrfTokenRepository requires the response as an attribute assertThat(this.request.getAttribute(HttpServletResponse.class.getName())).isEqualTo(this.response); verify(this.filterChain).doFilter(this.request, this.response); @@ -304,8 +304,8 @@ public class CsrfFilterTests { given(this.requestMatcher.matches(this.request)).willReturn(true); given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); - assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); verifyNoMoreInteractions(this.filterChain); } @@ -336,14 +336,14 @@ public class CsrfFilterTests { } @Test - public void doFilterWhenRequestAttributeHandlerThenUsed() throws Exception { - given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); - CsrfTokenRequestAttributeHandler requestAttributeHandler = mock(CsrfTokenRequestAttributeHandler.class); - this.filter.setRequestAttributeHandler(requestAttributeHandler); + public void doFilterWhenRequestHandlerThenUsed() throws Exception { + CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class); + given(requestHandler.handle(this.request, this.response)) + .willReturn(new TestDeferredCsrfToken(this.token, false)); + this.filter.setRequestHandler(requestHandler); this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); - verify(requestAttributeHandler).handle(eq(this.request), eq(this.response), any()); + verify(requestHandler).handle(eq(this.request), eq(this.response)); verify(this.filterChain).doFilter(this.request, this.response); } @@ -376,39 +376,40 @@ public class CsrfFilterTests { CsrfFilter filter = createCsrfFilter(this.tokenRepository); String csrfAttrName = "_csrf"; CsrfTokenRequestProcessor csrfTokenRequestProcessor = new CsrfTokenRequestProcessor(); + csrfTokenRequestProcessor.setTokenRepository(this.tokenRepository); csrfTokenRequestProcessor.setCsrfRequestAttributeName(csrfAttrName); - filter.setRequestAttributeHandler(csrfTokenRequestProcessor); - CsrfToken expectedCsrfToken = mock(CsrfToken.class); + filter.setRequestHandler(csrfTokenRequestProcessor); + CsrfToken expectedCsrfToken = spy(this.token); given(this.tokenRepository.loadToken(this.request)).willReturn(expectedCsrfToken); filter.doFilter(this.request, this.response, this.filterChain); verifyNoInteractions(expectedCsrfToken); CsrfToken tokenFromRequest = (CsrfToken) this.request.getAttribute(csrfAttrName); - assertThat(tokenFromRequest).isEqualTo(expectedCsrfToken); + assertThatCsrfToken(tokenFromRequest).isEqualTo(expectedCsrfToken); } - private static CsrfTokenAssert assertToken(Object token) { - return new CsrfTokenAssert((CsrfToken) token); - } + private static final class TestDeferredCsrfToken implements DeferredCsrfToken { - private static class CsrfTokenAssert extends AbstractObjectAssert { + private final CsrfToken csrfToken; - /** - * Creates a new {@link ObjectAssert}. - * @param actual the target to verify. - */ - protected CsrfTokenAssert(CsrfToken actual) { - super(actual, CsrfTokenAssert.class); + private final boolean isGenerated; + + private TestDeferredCsrfToken(CsrfToken csrfToken, boolean isGenerated) { + this.csrfToken = csrfToken; + this.isGenerated = isGenerated; } - CsrfTokenAssert isEqualTo(CsrfToken expected) { - assertThat(this.actual.getHeaderName()).isEqualTo(expected.getHeaderName()); - assertThat(this.actual.getParameterName()).isEqualTo(expected.getParameterName()); - assertThat(this.actual.getToken()).isEqualTo(expected.getToken()); - return this; + @Override + public CsrfToken get() { + return this.csrfToken; } - } + @Override + public boolean isGenerated() { + return this.isGenerated; + } + + }; } diff --git a/web/src/test/java/org/springframework/security/web/csrf/CsrfTokenAssert.java b/web/src/test/java/org/springframework/security/web/csrf/CsrfTokenAssert.java new file mode 100644 index 0000000000..cca2591110 --- /dev/null +++ b/web/src/test/java/org/springframework/security/web/csrf/CsrfTokenAssert.java @@ -0,0 +1,48 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web.csrf; + +import org.assertj.core.api.AbstractAssert; +import org.assertj.core.api.Assertions; + +/** + * Assertion for validating the properties on CsrfToken are the same. + */ +public class CsrfTokenAssert extends AbstractAssert { + + protected CsrfTokenAssert(CsrfToken csrfToken) { + super(csrfToken, CsrfTokenAssert.class); + } + + public static CsrfTokenAssert assertThatCsrfToken(Object csrfToken) { + return new CsrfTokenAssert((CsrfToken) csrfToken); + } + + public static CsrfTokenAssert assertThat(CsrfToken csrfToken) { + return new CsrfTokenAssert(csrfToken); + } + + public CsrfTokenAssert isEqualTo(CsrfToken csrfToken) { + isNotNull(); + assertThat(csrfToken).isNotNull(); + Assertions.assertThat(this.actual.getHeaderName()).isEqualTo(csrfToken.getHeaderName()); + Assertions.assertThat(this.actual.getParameterName()).isEqualTo(csrfToken.getParameterName()); + Assertions.assertThat(this.actual.getToken()).isEqualTo(csrfToken.getToken()); + return this; + } + +} diff --git a/web/src/test/java/org/springframework/security/web/csrf/CsrfTokenRequestProcessorTests.java b/web/src/test/java/org/springframework/security/web/csrf/CsrfTokenRequestProcessorTests.java index ac50ec3aaa..390305954a 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/CsrfTokenRequestProcessorTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/CsrfTokenRequestProcessorTests.java @@ -18,12 +18,17 @@ package org.springframework.security.web.csrf; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.BDDMockito.given; +import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken; /** * Tests for {@link CsrfTokenRequestProcessor}. @@ -31,8 +36,12 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException * @author Steve Riesenberg * @since 5.8 */ +@ExtendWith(MockitoExtension.class) public class CsrfTokenRequestProcessorTests { + @Mock + CsrfTokenRepository tokenRepository; + private MockHttpServletRequest request; private MockHttpServletResponse response; @@ -47,48 +56,36 @@ public class CsrfTokenRequestProcessorTests { this.response = new MockHttpServletResponse(); this.token = new DefaultCsrfToken("headerName", "paramName", "csrfTokenValue"); this.processor = new CsrfTokenRequestProcessor(); + this.processor.setTokenRepository(this.tokenRepository); } @Test public void handleWhenRequestIsNullThenThrowsIllegalArgumentException() { - assertThatIllegalArgumentException() - .isThrownBy(() -> this.processor.handle(null, this.response, () -> this.token)) + assertThatIllegalArgumentException().isThrownBy(() -> this.processor.handle(null, this.response)) .withMessage("request cannot be null"); } @Test public void handleWhenResponseIsNullThenThrowsIllegalArgumentException() { - assertThatIllegalArgumentException() - .isThrownBy(() -> this.processor.handle(this.request, null, () -> this.token)) + assertThatIllegalArgumentException().isThrownBy(() -> this.processor.handle(this.request, null)) .withMessage("response cannot be null"); } - @Test - public void handleWhenCsrfTokenSupplierIsNullThenThrowsIllegalArgumentException() { - assertThatIllegalArgumentException().isThrownBy(() -> this.processor.handle(this.request, this.response, null)) - .withMessage("csrfToken supplier cannot be null"); - } - - @Test - public void handleWhenCsrfTokenIsNullThenThrowsIllegalArgumentException() { - assertThatIllegalArgumentException() - .isThrownBy(() -> this.processor.handle(this.request, this.response, () -> null)) - .withMessage("csrfToken cannot be null"); - } - @Test public void handleWhenCsrfRequestAttributeSetThenUsed() { + given(this.tokenRepository.generateToken(this.request)).willReturn(this.token); this.processor.setCsrfRequestAttributeName("_csrf"); - this.processor.handle(this.request, this.response, () -> this.token); - assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); - assertThat(this.request.getAttribute("_csrf")).isEqualTo(this.token); + this.processor.handle(this.request, this.response); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute("_csrf")).isEqualTo(this.token); } @Test public void handleWhenValidParametersThenRequestAttributesSet() { - this.processor.handle(this.request, this.response, () -> this.token); - assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); - assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); + given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); + this.processor.handle(this.request, this.response); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); } @Test