From 475b3bb6bbdc737ba67820d8ca4e557b3e030e7c Mon Sep 17 00:00:00 2001 From: Steve Riesenberg Date: Tue, 27 Sep 2022 14:53:54 -0500 Subject: [PATCH] Add deferred CsrfTokenRepository.loadDeferredToken * Move DeferredCsrfToken to top-level and implement Supplier * Move RepositoryDeferredCsrfToken to top-level and make package-private * Add CsrfTokenRepository.loadToken(HttpServletRequest, HttpServletResponse) * Update CsrfFilter * Rename CsrfTokenRepositoryRequestHandler to CsrfTokenRequestAttributeHandler Issue gh-11892 Closes gh-11918 --- .../web/configurers/CsrfConfigurer.java | 12 +- .../config/http/CsrfBeanDefinitionParser.java | 16 +-- .../security/config/spring-security-5.8.rnc | 2 +- .../security/config/spring-security-5.8.xsd | 2 +- .../DeferHttpSessionJavaConfigTests.java | 4 +- .../web/configurers/CsrfConfigurerTests.java | 99 ++++++++++------ .../security/config/http/CsrfConfigTests.java | 10 +- .../CsrfConfigTests-WithRequestAttrName.xml | 2 +- .../http/DeferHttpSessionTests-Explicit.xml | 2 +- .../servlet/appendix/namespace/http.adoc | 2 +- .../SecurityMockMvcRequestPostProcessors.java | 85 +++++++------- .../test/web/support/WebTestUtils.java | 16 ++- ...yMockMvcRequestBuildersFormLoginTests.java | 6 +- ...MockMvcRequestBuildersFormLogoutTests.java | 6 +- ...estPostProcessorsCsrfDebugFilterTests.java | 74 ++++++++++++ .../test/web/support/WebTestUtilsTests.java | 15 ++- .../web/csrf/CsrfAuthenticationStrategy.java | 17 +-- .../security/web/csrf/CsrfFilter.java | 39 ++++--- .../web/csrf/CsrfTokenRepository.java | 19 ++- ... => CsrfTokenRequestAttributeHandler.java} | 71 +---------- .../web/csrf/CsrfTokenRequestHandler.java | 15 ++- .../web/csrf/CsrfTokenRequestResolver.java | 2 +- .../security/web/csrf/DeferredCsrfToken.java | 3 +- .../web/csrf/LazyCsrfTokenRepository.java | 5 +- .../web/csrf/RepositoryDeferredCsrfToken.java | 71 +++++++++++ .../csrf/CookieCsrfTokenRepositoryTests.java | 30 ++++- .../csrf/CsrfAuthenticationStrategyTests.java | 27 ++--- .../security/web/csrf/CsrfFilterTests.java | 110 ++++++++---------- ...srfTokenRequestAttributeHandlerTests.java} | 64 +++++----- .../HttpSessionCsrfTokenRepositoryTests.java | 23 +++- .../web/csrf/TestDeferredCsrfToken.java | 40 +++++++ 31 files changed, 536 insertions(+), 353 deletions(-) create mode 100644 test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfDebugFilterTests.java rename web/src/main/java/org/springframework/security/web/csrf/{CsrfTokenRepositoryRequestHandler.java => CsrfTokenRequestAttributeHandler.java} (59%) create mode 100644 web/src/main/java/org/springframework/security/web/csrf/RepositoryDeferredCsrfToken.java rename web/src/test/java/org/springframework/security/web/csrf/{CsrfTokenRepositoryRequestHandlerTests.java => CsrfTokenRequestAttributeHandlerTests.java} (72%) create mode 100644 web/src/test/java/org/springframework/security/web/csrf/TestDeferredCsrfToken.java diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java index 3f0ff7a477..f2c6e41707 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java @@ -36,7 +36,6 @@ import org.springframework.security.web.csrf.CsrfAuthenticationStrategy; import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfLogoutHandler; import org.springframework.security.web.csrf.CsrfTokenRepository; -import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler; import org.springframework.security.web.csrf.CsrfTokenRequestHandler; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; import org.springframework.security.web.csrf.LazyCsrfTokenRepository; @@ -249,13 +248,7 @@ public final class CsrfConfigurer> @SuppressWarnings("unchecked") @Override public void configure(H http) { - CsrfFilter filter; - if (this.requestHandler != null) { - filter = new CsrfFilter(this.requestHandler); - } - else { - filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.csrfTokenRepository)); - } + CsrfFilter filter = new CsrfFilter(this.csrfTokenRepository); RequestMatcher requireCsrfProtectionMatcher = getRequireCsrfProtectionMatcher(); if (requireCsrfProtectionMatcher != null) { filter.setRequireCsrfProtectionMatcher(requireCsrfProtectionMatcher); @@ -272,6 +265,9 @@ public final class CsrfConfigurer> if (sessionConfigurer != null) { sessionConfigurer.addSessionAuthenticationStrategy(getSessionAuthenticationStrategy()); } + if (this.requestHandler != null) { + filter.setRequestHandler(this.requestHandler); + } filter = postProcess(filter); http.addFilter(filter); } diff --git a/config/src/main/java/org/springframework/security/config/http/CsrfBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/http/CsrfBeanDefinitionParser.java index 1a7009edee..ac3f40dd94 100644 --- a/config/src/main/java/org/springframework/security/config/http/CsrfBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/http/CsrfBeanDefinitionParser.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -41,7 +41,6 @@ import org.springframework.security.web.access.DelegatingAccessDeniedHandler; import org.springframework.security.web.csrf.CsrfAuthenticationStrategy; import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfLogoutHandler; -import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; import org.springframework.security.web.csrf.LazyCsrfTokenRepository; import org.springframework.security.web.csrf.MissingCsrfTokenException; @@ -112,18 +111,13 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser { new BeanComponentDefinition(lazyTokenRepository.getBeanDefinition(), this.csrfRepositoryRef)); } BeanDefinitionBuilder builder = BeanDefinitionBuilder.rootBeanDefinition(CsrfFilter.class); - if (!StringUtils.hasText(this.requestHandlerRef)) { - BeanDefinition csrfTokenRequestHandler = BeanDefinitionBuilder - .rootBeanDefinition(CsrfTokenRepositoryRequestHandler.class) - .addConstructorArgReference(this.csrfRepositoryRef).getBeanDefinition(); - builder.addConstructorArgValue(csrfTokenRequestHandler); - } - else { - builder.addConstructorArgReference(this.requestHandlerRef); - } + builder.addConstructorArgReference(this.csrfRepositoryRef); if (StringUtils.hasText(this.requestMatcherRef)) { builder.addPropertyReference("requireCsrfProtectionMatcher", this.requestMatcherRef); } + if (StringUtils.hasText(this.requestHandlerRef)) { + builder.addPropertyReference("requestHandler", this.requestHandlerRef); + } this.csrfFilter = builder.getBeanDefinition(); return this.csrfFilter; } diff --git a/config/src/main/resources/org/springframework/security/config/spring-security-5.8.rnc b/config/src/main/resources/org/springframework/security/config/spring-security-5.8.rnc index 5ee3a4b885..32f8f4b584 100644 --- a/config/src/main/resources/org/springframework/security/config/spring-security-5.8.rnc +++ b/config/src/main/resources/org/springframework/security/config/spring-security-5.8.rnc @@ -1152,7 +1152,7 @@ csrf-options.attlist &= ## The CsrfTokenRepository to use. The default is HttpSessionCsrfTokenRepository wrapped by LazyCsrfTokenRepository. attribute token-repository-ref { xsd:token }? csrf-options.attlist &= - ## The CsrfTokenRequestHandler to use. The default is CsrfTokenRepositoryRequestHandler. + ## The CsrfTokenRequestHandler to use. The default is CsrfTokenRequestAttributeHandler. attribute request-handler-ref { xsd:token }? headers = diff --git a/config/src/main/resources/org/springframework/security/config/spring-security-5.8.xsd b/config/src/main/resources/org/springframework/security/config/spring-security-5.8.xsd index 9a61dd6d41..c616c29b82 100644 --- a/config/src/main/resources/org/springframework/security/config/spring-security-5.8.xsd +++ b/config/src/main/resources/org/springframework/security/config/spring-security-5.8.xsd @@ -3258,7 +3258,7 @@ - The CsrfTokenRequestHandler to use. The default is CsrfTokenRepositoryRequestHandler. + The CsrfTokenRequestHandler to use. The default is CsrfTokenRequestAttributeHandler. diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/DeferHttpSessionJavaConfigTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/DeferHttpSessionJavaConfigTests.java index e1d1181357..fe5cca2527 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/DeferHttpSessionJavaConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/DeferHttpSessionJavaConfigTests.java @@ -33,7 +33,7 @@ import org.springframework.security.config.test.SpringTestContext; import org.springframework.security.config.test.SpringTestContextExtension; import org.springframework.security.web.DefaultSecurityFilterChain; import org.springframework.security.web.FilterChainProxy; -import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler; +import org.springframework.security.web.csrf.CsrfTokenRequestAttributeHandler; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; import org.springframework.security.web.csrf.LazyCsrfTokenRepository; import org.springframework.security.web.savedrequest.HttpSessionRequestCache; @@ -85,7 +85,7 @@ public class DeferHttpSessionJavaConfigTests { csrfRepository.setDeferLoadToken(true); HttpSessionRequestCache requestCache = new HttpSessionRequestCache(); requestCache.setMatchingRequestParameterName("continue"); - CsrfTokenRepositoryRequestHandler requestHandler = new CsrfTokenRepositoryRequestHandler(); + CsrfTokenRequestAttributeHandler requestHandler = new CsrfTokenRequestAttributeHandler(); requestHandler.setCsrfRequestAttributeName("_csrf"); // @formatter:off http diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java index c02e05bdf3..ccfe5fd2bb 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java @@ -44,8 +44,10 @@ import org.springframework.security.web.access.AccessDeniedHandler; import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy; import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.CsrfTokenRepository; -import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler; +import org.springframework.security.web.csrf.CsrfTokenRequestAttributeHandler; +import org.springframework.security.web.csrf.CsrfTokenRequestHandler; import org.springframework.security.web.csrf.DefaultCsrfToken; +import org.springframework.security.web.csrf.DeferredCsrfToken; import org.springframework.security.web.firewall.StrictHttpFirewall; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; @@ -61,7 +63,6 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.hamcrest.Matchers.containsString; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isNull; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.atLeastOnce; @@ -207,30 +208,30 @@ public class CsrfConfigurerTests { public void loginWhenCsrfEnabledThenDoesNotRedirectToPreviousPostRequest() throws Exception { CsrfDisablesPostRequestFromRequestCacheConfig.REPO = mock(CsrfTokenRepository.class); DefaultCsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"); - given(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.loadToken(any())).willReturn(csrfToken); - given(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.generateToken(any())).willReturn(csrfToken); + given(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.loadDeferredToken(any(HttpServletRequest.class), + any(HttpServletResponse.class))).willReturn(new TestDeferredCsrfToken(csrfToken)); this.spring.register(CsrfDisablesPostRequestFromRequestCacheConfig.class).autowire(); MvcResult mvcResult = this.mvc.perform(post("/some-url")).andReturn(); this.mvc.perform(post("/login").param("username", "user").param("password", "password").with(csrf()) .session((MockHttpSession) mvcResult.getRequest().getSession())).andExpect(status().isFound()) .andExpect(redirectedUrl("/")); verify(CsrfDisablesPostRequestFromRequestCacheConfig.REPO, atLeastOnce()) - .loadToken(any(HttpServletRequest.class)); + .loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class)); } @Test public void loginWhenCsrfEnabledThenRedirectsToPreviousGetRequest() throws Exception { CsrfDisablesPostRequestFromRequestCacheConfig.REPO = mock(CsrfTokenRepository.class); DefaultCsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"); - given(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.loadToken(any())).willReturn(csrfToken); - given(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.generateToken(any())).willReturn(csrfToken); + given(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.loadDeferredToken(any(HttpServletRequest.class), + any(HttpServletResponse.class))).willReturn(new TestDeferredCsrfToken(csrfToken)); this.spring.register(CsrfDisablesPostRequestFromRequestCacheConfig.class).autowire(); MvcResult mvcResult = this.mvc.perform(get("/some-url")).andReturn(); this.mvc.perform(post("/login").param("username", "user").param("password", "password").with(csrf()) .session((MockHttpSession) mvcResult.getRequest().getSession())).andExpect(status().isFound()) .andExpect(redirectedUrl("http://localhost/some-url")); verify(CsrfDisablesPostRequestFromRequestCacheConfig.REPO, atLeastOnce()) - .loadToken(any(HttpServletRequest.class)); + .loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class)); } // SEC-2422 @@ -277,11 +278,13 @@ public class CsrfConfigurerTests { @Test public void getWhenCustomCsrfTokenRepositoryThenRepositoryIsUsed() throws Exception { CsrfTokenRepositoryConfig.REPO = mock(CsrfTokenRepository.class); - given(CsrfTokenRepositoryConfig.REPO.loadToken(any())) - .willReturn(new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token")); + given(CsrfTokenRepositoryConfig.REPO.loadDeferredToken(any(HttpServletRequest.class), + any(HttpServletResponse.class))) + .willReturn(new TestDeferredCsrfToken(new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"))); this.spring.register(CsrfTokenRepositoryConfig.class, BasicController.class).autowire(); this.mvc.perform(get("/")).andExpect(status().isOk()); - verify(CsrfTokenRepositoryConfig.REPO).loadToken(any(HttpServletRequest.class)); + verify(CsrfTokenRepositoryConfig.REPO).loadDeferredToken(any(HttpServletRequest.class), + any(HttpServletResponse.class)); } @Test @@ -297,8 +300,8 @@ public class CsrfConfigurerTests { public void loginWhenCustomCsrfTokenRepositoryThenCsrfTokenIsCleared() throws Exception { CsrfTokenRepositoryConfig.REPO = mock(CsrfTokenRepository.class); DefaultCsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"); - given(CsrfTokenRepositoryConfig.REPO.loadToken(any())).willReturn(csrfToken); - given(CsrfTokenRepositoryConfig.REPO.generateToken(any())).willReturn(csrfToken); + given(CsrfTokenRepositoryConfig.REPO.loadDeferredToken(any(HttpServletRequest.class), + any(HttpServletResponse.class))).willReturn(new TestDeferredCsrfToken(csrfToken)); this.spring.register(CsrfTokenRepositoryConfig.class, BasicController.class).autowire(); // @formatter:off MockHttpServletRequestBuilder loginRequest = post("/login") @@ -314,11 +317,13 @@ public class CsrfConfigurerTests { @Test public void getWhenCustomCsrfTokenRepositoryInLambdaThenRepositoryIsUsed() throws Exception { CsrfTokenRepositoryInLambdaConfig.REPO = mock(CsrfTokenRepository.class); - given(CsrfTokenRepositoryInLambdaConfig.REPO.loadToken(any())) - .willReturn(new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token")); + given(CsrfTokenRepositoryInLambdaConfig.REPO.loadDeferredToken(any(HttpServletRequest.class), + any(HttpServletResponse.class))) + .willReturn(new TestDeferredCsrfToken(new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"))); this.spring.register(CsrfTokenRepositoryInLambdaConfig.class, BasicController.class).autowire(); this.mvc.perform(get("/")).andExpect(status().isOk()); - verify(CsrfTokenRepositoryInLambdaConfig.REPO).loadToken(any(HttpServletRequest.class)); + verify(CsrfTokenRepositoryInLambdaConfig.REPO).loadDeferredToken(any(HttpServletRequest.class), + any(HttpServletResponse.class)); } @Test @@ -418,30 +423,30 @@ public class CsrfConfigurerTests { } @Test - public void getLoginWhenCsrfTokenRequestProcessorSetThenRespondsWithNormalCsrfToken() throws Exception { + public void getLoginWhenCsrfTokenRequestHandlerSetThenRespondsWithNormalCsrfToken() throws Exception { CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class); CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"); - given(csrfTokenRepository.generateToken(any(HttpServletRequest.class))).willReturn(csrfToken); - CsrfTokenRequestProcessorConfig.HANDLER = new CsrfTokenRepositoryRequestHandler(csrfTokenRepository); - this.spring.register(CsrfTokenRequestProcessorConfig.class, BasicController.class).autowire(); + given(csrfTokenRepository.loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class))) + .willReturn(new TestDeferredCsrfToken(csrfToken)); + CsrfTokenRequestHandlerConfig.REPO = csrfTokenRepository; + CsrfTokenRequestHandlerConfig.HANDLER = new CsrfTokenRequestAttributeHandler(); + this.spring.register(CsrfTokenRequestHandlerConfig.class, BasicController.class).autowire(); this.mvc.perform(get("/login")).andExpect(status().isOk()) .andExpect(content().string(containsString(csrfToken.getToken()))); - verify(csrfTokenRepository).loadToken(any(HttpServletRequest.class)); - verify(csrfTokenRepository).generateToken(any(HttpServletRequest.class)); - verify(csrfTokenRepository).saveToken(eq(csrfToken), any(HttpServletRequest.class), - any(HttpServletResponse.class)); + verify(csrfTokenRepository).loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class)); verifyNoMoreInteractions(csrfTokenRepository); } @Test - public void loginWhenCsrfTokenRequestProcessorSetAndNormalCsrfTokenThenSuccess() throws Exception { + public void loginWhenCsrfTokenRequestHandlerSetAndNormalCsrfTokenThenSuccess() throws Exception { CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"); CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class); - given(csrfTokenRepository.loadToken(any(HttpServletRequest.class))).willReturn(null, csrfToken); - given(csrfTokenRepository.generateToken(any(HttpServletRequest.class))).willReturn(csrfToken); - CsrfTokenRequestProcessorConfig.HANDLER = new CsrfTokenRepositoryRequestHandler(csrfTokenRepository); + given(csrfTokenRepository.loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class))) + .willReturn(new TestDeferredCsrfToken(csrfToken)); + CsrfTokenRequestHandlerConfig.REPO = csrfTokenRepository; + CsrfTokenRequestHandlerConfig.HANDLER = new CsrfTokenRequestAttributeHandler(); + this.spring.register(CsrfTokenRequestHandlerConfig.class, BasicController.class).autowire(); - this.spring.register(CsrfTokenRequestProcessorConfig.class, BasicController.class).autowire(); // @formatter:off MockHttpServletRequestBuilder loginRequest = post("/login") .header(csrfToken.getHeaderName(), csrfToken.getToken()) @@ -449,9 +454,8 @@ public class CsrfConfigurerTests { .param("password", "password"); // @formatter:on this.mvc.perform(loginRequest).andExpect(redirectedUrl("/")); - verify(csrfTokenRepository, times(2)).loadToken(any(HttpServletRequest.class)); - verify(csrfTokenRepository).generateToken(any(HttpServletRequest.class)); - verify(csrfTokenRepository).saveToken(eq(csrfToken), any(HttpServletRequest.class), + verify(csrfTokenRepository).saveToken(isNull(), any(HttpServletRequest.class), any(HttpServletResponse.class)); + verify(csrfTokenRepository, times(2)).loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class)); verifyNoMoreInteractions(csrfTokenRepository); } @@ -799,9 +803,11 @@ public class CsrfConfigurerTests { @Configuration @EnableWebSecurity - static class CsrfTokenRequestProcessorConfig { + static class CsrfTokenRequestHandlerConfig { - static CsrfTokenRepositoryRequestHandler HANDLER; + static CsrfTokenRepository REPO; + + static CsrfTokenRequestHandler HANDLER; @Bean SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception { @@ -811,7 +817,10 @@ public class CsrfConfigurerTests { .anyRequest().authenticated() ) .formLogin(Customizer.withDefaults()) - .csrf((csrf) -> csrf.csrfTokenRequestHandler(HANDLER)); + .csrf((csrf) -> csrf + .csrfTokenRepository(REPO) + .csrfTokenRequestHandler(HANDLER) + ); // @formatter:on return http.build(); @@ -841,4 +850,24 @@ public class CsrfConfigurerTests { } + private static final class TestDeferredCsrfToken implements DeferredCsrfToken { + + private final CsrfToken csrfToken; + + private TestDeferredCsrfToken(CsrfToken csrfToken) { + this.csrfToken = csrfToken; + } + + @Override + public CsrfToken get() { + return this.csrfToken; + } + + @Override + public boolean isGenerated() { + return false; + } + + } + } diff --git a/config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java b/config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java index 1f35855205..e9220895fb 100644 --- a/config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java @@ -30,7 +30,6 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.http.HttpMethod; import org.springframework.mock.web.MockHttpServletRequest; -import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpSession; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.config.test.SpringTestContext; @@ -42,7 +41,6 @@ import org.springframework.security.web.FilterChainProxy; import org.springframework.security.web.access.AccessDeniedHandler; import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfToken; -import org.springframework.security.web.csrf.DeferredCsrfToken; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.stereotype.Controller; import org.springframework.test.context.junit.jupiter.SpringExtension; @@ -546,9 +544,8 @@ public class CsrfConfigTests { @Override public void match(MvcResult result) { MockHttpServletRequest request = result.getRequest(); - MockHttpServletResponse response = result.getResponse(); - DeferredCsrfToken token = WebTestUtils.getCsrfTokenRequestHandler(request).handle(request, response); - assertThat(token.isGenerated()).isFalse(); + CsrfToken token = WebTestUtils.getCsrfTokenRepository(request).loadToken(request); + assertThat(token).isNotNull(); } } @@ -564,8 +561,7 @@ public class CsrfConfigTests { @Override public void match(MvcResult result) throws Exception { MockHttpServletRequest request = result.getRequest(); - MockHttpServletResponse response = result.getResponse(); - CsrfToken token = WebTestUtils.getCsrfTokenRequestHandler(request).handle(request, response).get(); + CsrfToken token = WebTestUtils.getCsrfTokenRepository(request).loadToken(request); assertThat(token).isNotNull(); assertThat(token.getToken()).isEqualTo(this.token.apply(result)); } diff --git a/config/src/test/resources/org/springframework/security/config/http/CsrfConfigTests-WithRequestAttrName.xml b/config/src/test/resources/org/springframework/security/config/http/CsrfConfigTests-WithRequestAttrName.xml index 2ea2dd7c87..e8ec4b8f01 100644 --- a/config/src/test/resources/org/springframework/security/config/http/CsrfConfigTests-WithRequestAttrName.xml +++ b/config/src/test/resources/org/springframework/security/config/http/CsrfConfigTests-WithRequestAttrName.xml @@ -26,7 +26,7 @@ - diff --git a/config/src/test/resources/org/springframework/security/config/http/DeferHttpSessionTests-Explicit.xml b/config/src/test/resources/org/springframework/security/config/http/DeferHttpSessionTests-Explicit.xml index e15d072002..cbfdfa90a3 100644 --- a/config/src/test/resources/org/springframework/security/config/http/DeferHttpSessionTests-Explicit.xml +++ b/config/src/test/resources/org/springframework/security/config/http/DeferHttpSessionTests-Explicit.xml @@ -42,7 +42,7 @@ - diff --git a/docs/modules/ROOT/pages/servlet/appendix/namespace/http.adoc b/docs/modules/ROOT/pages/servlet/appendix/namespace/http.adoc index f660c8b9e1..cb5f296c5f 100644 --- a/docs/modules/ROOT/pages/servlet/appendix/namespace/http.adoc +++ b/docs/modules/ROOT/pages/servlet/appendix/namespace/http.adoc @@ -783,7 +783,7 @@ The default is `HttpSessionCsrfTokenRepository`. [[nsa-csrf-request-handler-ref]] * **request-handler-ref** -The optional `CsrfTokenRequestHandler` to use. The default is `CsrfTokenRepositoryRequestHandler`. +The optional `CsrfTokenRequestHandler` to use. The default is `CsrfTokenRequestAttributeHandler`. [[nsa-csrf-request-matcher-ref]] * **request-matcher-ref** diff --git a/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java b/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java index 1419b0e497..4eb0f6b324 100644 --- a/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java +++ b/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java @@ -94,8 +94,7 @@ import org.springframework.security.web.context.SecurityContextPersistenceFilter import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfToken; -import org.springframework.security.web.csrf.CsrfTokenRequestHandler; -import org.springframework.security.web.csrf.DeferredCsrfToken; +import org.springframework.security.web.csrf.CsrfTokenRepository; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; import org.springframework.test.util.ReflectionTestUtils; import org.springframework.test.web.servlet.MockMvc; @@ -509,13 +508,14 @@ public final class SecurityMockMvcRequestPostProcessors { @Override public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { - CsrfTokenRequestHandler handler = WebTestUtils.getCsrfTokenRequestHandler(request); - if (!(handler instanceof TestCsrfTokenRequestHandler)) { - handler = new TestCsrfTokenRequestHandler(handler); - WebTestUtils.setCsrfTokenRequestHandler(request, handler); + CsrfTokenRepository repository = WebTestUtils.getCsrfTokenRepository(request); + if (!(repository instanceof TestCsrfTokenRepository)) { + repository = new TestCsrfTokenRepository(new HttpSessionCsrfTokenRepository()); + WebTestUtils.setCsrfTokenRepository(request, repository); } - TestCsrfTokenRequestHandler testHandler = (TestCsrfTokenRequestHandler) handler; - CsrfToken token = TestCsrfTokenRequestHandler.createTestCsrfToken(request); + TestCsrfTokenRepository.enable(request); + CsrfToken token = repository.generateToken(request); + repository.saveToken(token, request, new MockHttpServletResponse()); String tokenValue = this.useInvalidToken ? "invalid" + token.getToken() : token.getToken(); if (this.asHeader) { request.addHeader(token.getHeaderName(), tokenValue); @@ -549,56 +549,49 @@ public final class SecurityMockMvcRequestPostProcessors { * Used to wrap the CsrfTokenRepository to provide support for testing when the * request is wrapped (i.e. Spring Session is in use). */ - static class TestCsrfTokenRequestHandler implements CsrfTokenRequestHandler { + static class TestCsrfTokenRepository implements CsrfTokenRepository { - static final String TOKEN_ATTR_NAME = TestCsrfTokenRequestHandler.class.getName().concat(".TOKEN"); + static final String TOKEN_ATTR_NAME = TestCsrfTokenRepository.class.getName().concat(".TOKEN"); - static final String ENABLED_ATTR_NAME = TestCsrfTokenRequestHandler.class.getName().concat(".ENABLED"); + static final String ENABLED_ATTR_NAME = TestCsrfTokenRepository.class.getName().concat(".ENABLED"); - private final CsrfTokenRequestHandler delegate; + private final CsrfTokenRepository delegate; - TestCsrfTokenRequestHandler(CsrfTokenRequestHandler delegate) { + TestCsrfTokenRepository(CsrfTokenRepository delegate) { this.delegate = delegate; } - static CsrfToken createTestCsrfToken(HttpServletRequest request) { - CsrfToken existingToken = getExistingToken(request); - if (existingToken != null) { - return existingToken; - } - HttpSessionCsrfTokenRepository repository = new HttpSessionCsrfTokenRepository(); - CsrfToken csrfToken = repository.generateToken(request); - request.setAttribute(ENABLED_ATTR_NAME, true); - request.setAttribute(TOKEN_ATTR_NAME, csrfToken); - return csrfToken; - } - - private static CsrfToken getExistingToken(HttpServletRequest request) { - Object existingToken = request.getAttribute(TOKEN_ATTR_NAME); - return (CsrfToken) existingToken; - } - - boolean isEnabled(HttpServletRequest request) { - return getExistingToken(request) != null; + @Override + public CsrfToken generateToken(HttpServletRequest request) { + return this.delegate.generateToken(request); } @Override - public DeferredCsrfToken handle(HttpServletRequest request, HttpServletResponse response) { - request.setAttribute(HttpServletResponse.class.getName(), response); - if (!isEnabled(request)) { - return this.delegate.handle(request, response); + public void saveToken(CsrfToken token, HttpServletRequest request, HttpServletResponse response) { + if (isEnabled(request)) { + request.setAttribute(TOKEN_ATTR_NAME, token); } - return new DeferredCsrfToken() { - @Override - public CsrfToken get() { - return getExistingToken(request); - } + else { + this.delegate.saveToken(token, request, response); + } + } - @Override - public boolean isGenerated() { - return false; - } - }; + @Override + public CsrfToken loadToken(HttpServletRequest request) { + if (isEnabled(request)) { + return (CsrfToken) request.getAttribute(TOKEN_ATTR_NAME); + } + else { + return this.delegate.loadToken(request); + } + } + + static void enable(HttpServletRequest request) { + request.setAttribute(ENABLED_ATTR_NAME, Boolean.TRUE); + } + + boolean isEnabled(HttpServletRequest request) { + return Boolean.TRUE.equals(request.getAttribute(ENABLED_ATTR_NAME)); } } diff --git a/test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java b/test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java index 5f2fefec13..c13ebdefe3 100644 --- a/test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java +++ b/test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java @@ -31,8 +31,6 @@ import org.springframework.security.web.context.SecurityContextPersistenceFilter import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfTokenRepository; -import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler; -import org.springframework.security.web.csrf.CsrfTokenRequestHandler; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; import org.springframework.test.util.ReflectionTestUtils; import org.springframework.web.context.WebApplicationContext; @@ -48,7 +46,7 @@ public abstract class WebTestUtils { private static final SecurityContextRepository DEFAULT_CONTEXT_REPO = new HttpSessionSecurityContextRepository(); - private static final CsrfTokenRepositoryRequestHandler DEFAULT_CSRF_HANDLER = new CsrfTokenRepositoryRequestHandler(); + private static final CsrfTokenRepository DEFAULT_TOKEN_REPO = new HttpSessionCsrfTokenRepository(); private WebTestUtils() { } @@ -101,24 +99,24 @@ public abstract class WebTestUtils { * @return the {@link CsrfTokenRepository} for the specified * {@link HttpServletRequest} */ - public static CsrfTokenRequestHandler getCsrfTokenRequestHandler(HttpServletRequest request) { + public static CsrfTokenRepository getCsrfTokenRepository(HttpServletRequest request) { CsrfFilter filter = findFilter(request, CsrfFilter.class); if (filter == null) { - return DEFAULT_CSRF_HANDLER; + return DEFAULT_TOKEN_REPO; } - return (CsrfTokenRequestHandler) ReflectionTestUtils.getField(filter, "requestHandler"); + return (CsrfTokenRepository) ReflectionTestUtils.getField(filter, "tokenRepository"); } /** * Sets the {@link CsrfTokenRepository} for the specified {@link HttpServletRequest}. * @param request the {@link HttpServletRequest} to obtain the * {@link CsrfTokenRepository} - * @param handler the {@link CsrfTokenRepository} to set + * @param repository the {@link CsrfTokenRepository} to set */ - public static void setCsrfTokenRequestHandler(HttpServletRequest request, CsrfTokenRequestHandler handler) { + public static void setCsrfTokenRepository(HttpServletRequest request, CsrfTokenRepository repository) { CsrfFilter filter = findFilter(request, CsrfFilter.class); if (filter != null) { - ReflectionTestUtils.setField(filter, "requestHandler", handler); + ReflectionTestUtils.setField(filter, "tokenRepository", repository); } } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java index 374aa68414..9dea5175bf 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java @@ -53,7 +53,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests { public void defaults() { MockHttpServletRequest request = formLogin().buildRequest(this.servletContext); CsrfToken token = (CsrfToken) request - .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRequestHandler.TOKEN_ATTR_NAME); + .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); assertThat(request.getParameter("username")).isEqualTo("user"); assertThat(request.getParameter("password")).isEqualTo("password"); assertThat(request.getMethod()).isEqualTo("POST"); @@ -67,7 +67,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests { MockHttpServletRequest request = formLogin("/login").user("username", "admin").password("password", "secret") .buildRequest(this.servletContext); CsrfToken token = (CsrfToken) request - .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRequestHandler.TOKEN_ATTR_NAME); + .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); assertThat(request.getParameter("username")).isEqualTo("admin"); assertThat(request.getParameter("password")).isEqualTo("secret"); assertThat(request.getMethod()).isEqualTo("POST"); @@ -80,7 +80,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests { MockHttpServletRequest request = formLogin().loginProcessingUrl("/uri-login/{var1}/{var2}", "val1", "val2") .user("username", "admin").password("password", "secret").buildRequest(this.servletContext); CsrfToken token = (CsrfToken) request - .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRequestHandler.TOKEN_ATTR_NAME); + .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); assertThat(request.getParameter("username")).isEqualTo("admin"); assertThat(request.getParameter("password")).isEqualTo("secret"); assertThat(request.getMethod()).isEqualTo("POST"); diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java index c6856fb821..df6e7cfef2 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java @@ -53,7 +53,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests { public void defaults() { MockHttpServletRequest request = logout().buildRequest(this.servletContext); CsrfToken token = (CsrfToken) request - .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRequestHandler.TOKEN_ATTR_NAME); + .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken()); assertThat(request.getRequestURI()).isEqualTo("/logout"); @@ -63,7 +63,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests { public void custom() { MockHttpServletRequest request = logout("/admin/logout").buildRequest(this.servletContext); CsrfToken token = (CsrfToken) request - .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRequestHandler.TOKEN_ATTR_NAME); + .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken()); assertThat(request.getRequestURI()).isEqualTo("/admin/logout"); @@ -74,7 +74,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests { MockHttpServletRequest request = logout().logoutUrl("/uri-logout/{var1}/{var2}", "val1", "val2") .buildRequest(this.servletContext); CsrfToken token = (CsrfToken) request - .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRequestHandler.TOKEN_ATTR_NAME); + .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken()); assertThat(request.getRequestURI()).isEqualTo("/uri-logout/val1/val2"); diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfDebugFilterTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfDebugFilterTests.java new file mode 100644 index 0000000000..acb81a8134 --- /dev/null +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfDebugFilterTests.java @@ -0,0 +1,74 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.test.web.servlet.request; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.security.config.annotation.web.builders.HttpSecurity; +import org.springframework.security.config.annotation.web.builders.WebSecurity; +import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; +import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; +import org.springframework.security.test.web.support.WebTestUtils; +import org.springframework.security.web.csrf.CookieCsrfTokenRepository; +import org.springframework.security.web.csrf.CsrfTokenRepository; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit.jupiter.SpringExtension; +import org.springframework.test.context.web.WebAppConfiguration; +import org.springframework.web.context.WebApplicationContext; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; + +@ExtendWith(SpringExtension.class) +@ContextConfiguration +@WebAppConfiguration +public class SecurityMockMvcRequestPostProcessorsCsrfDebugFilterTests { + + @Autowired + private WebApplicationContext wac; + + // SEC-3836 + @Test + public void findCookieCsrfTokenRepository() { + MockHttpServletRequest request = post("/").buildRequest(this.wac.getServletContext()); + CsrfTokenRepository csrfTokenRepository = WebTestUtils.getCsrfTokenRepository(request); + assertThat(csrfTokenRepository).isNotNull(); + assertThat(csrfTokenRepository).isEqualTo(Config.cookieCsrfTokenRepository); + } + + @EnableWebSecurity + static class Config extends WebSecurityConfigurerAdapter { + + static CsrfTokenRepository cookieCsrfTokenRepository = new CookieCsrfTokenRepository(); + + @Override + protected void configure(HttpSecurity http) throws Exception { + http.csrf().csrfTokenRepository(cookieCsrfTokenRepository); + } + + @Override + public void configure(WebSecurity web) { + // Enable the DebugFilter + web.debug(true); + } + + } + +} diff --git a/test/src/test/java/org/springframework/security/test/web/support/WebTestUtilsTests.java b/test/src/test/java/org/springframework/security/test/web/support/WebTestUtilsTests.java index 2dd0943024..c304202b4d 100644 --- a/test/src/test/java/org/springframework/security/test/web/support/WebTestUtilsTests.java +++ b/test/src/test/java/org/springframework/security/test/web/support/WebTestUtilsTests.java @@ -39,7 +39,6 @@ import org.springframework.security.web.context.SecurityContextPersistenceFilter import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfTokenRepository; -import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; import org.springframework.security.web.util.matcher.AnyRequestMatcher; import org.springframework.web.context.WebApplicationContext; @@ -75,22 +74,22 @@ public class WebTestUtilsTests { @Test public void getCsrfTokenRepositorytNoWac() { - assertThat(WebTestUtils.getCsrfTokenRequestHandler(this.request)) - .isInstanceOf(CsrfTokenRepositoryRequestHandler.class); + assertThat(WebTestUtils.getCsrfTokenRepository(this.request)) + .isInstanceOf(HttpSessionCsrfTokenRepository.class); } @Test public void getCsrfTokenRepositorytNoSecurity() { loadConfig(Config.class); - assertThat(WebTestUtils.getCsrfTokenRequestHandler(this.request)) - .isInstanceOf(CsrfTokenRepositoryRequestHandler.class); + assertThat(WebTestUtils.getCsrfTokenRepository(this.request)) + .isInstanceOf(HttpSessionCsrfTokenRepository.class); } @Test public void getCsrfTokenRepositorytSecurityNoCsrf() { loadConfig(SecurityNoCsrfConfig.class); - assertThat(WebTestUtils.getCsrfTokenRequestHandler(this.request)) - .isInstanceOf(CsrfTokenRepositoryRequestHandler.class); + assertThat(WebTestUtils.getCsrfTokenRepository(this.request)) + .isInstanceOf(HttpSessionCsrfTokenRepository.class); } @Test @@ -98,7 +97,7 @@ public class WebTestUtilsTests { CustomSecurityConfig.CONTEXT_REPO = this.contextRepo; CustomSecurityConfig.CSRF_REPO = this.csrfRepo; loadConfig(CustomSecurityConfig.class); - // assertThat(WebTestUtils.getCsrfTokenRepository(this.request)).isSameAs(this.csrfRepo); + assertThat(WebTestUtils.getCsrfTokenRepository(this.request)).isSameAs(this.csrfRepo); } // getSecurityContextRepository diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java index 2c17d2edf6..552e468cf3 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java @@ -39,17 +39,17 @@ public final class CsrfAuthenticationStrategy implements SessionAuthenticationSt private final Log logger = LogFactory.getLog(getClass()); - private final CsrfTokenRepository csrfTokenRepository; + private final CsrfTokenRepository tokenRepository; - private CsrfTokenRequestHandler requestHandler; + private CsrfTokenRequestHandler requestHandler = new CsrfTokenRequestAttributeHandler(); /** * Creates a new instance - * @param csrfTokenRepository the {@link CsrfTokenRepository} to use + * @param tokenRepository the {@link CsrfTokenRepository} to use */ - public CsrfAuthenticationStrategy(CsrfTokenRepository csrfTokenRepository) { - this.requestHandler = new CsrfTokenRepositoryRequestHandler(csrfTokenRepository); - this.csrfTokenRepository = csrfTokenRepository; + public CsrfAuthenticationStrategy(CsrfTokenRepository tokenRepository) { + Assert.notNull(tokenRepository, "tokenRepository cannot be null"); + this.tokenRepository = tokenRepository; } /** @@ -65,8 +65,9 @@ public final class CsrfAuthenticationStrategy implements SessionAuthenticationSt @Override public void onAuthentication(Authentication authentication, HttpServletRequest request, HttpServletResponse response) throws SessionAuthenticationException { - this.csrfTokenRepository.saveToken(null, request, response); - this.requestHandler.handle(request, response); + this.tokenRepository.saveToken(null, request, response); + DeferredCsrfToken deferredCsrfToken = this.tokenRepository.loadDeferredToken(request, response); + this.requestHandler.handle(request, response, deferredCsrfToken::get); this.logger.debug("Replaced CSRF Token"); } diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java index 71761b9960..5f3b94b6c9 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java @@ -82,30 +82,21 @@ public final class CsrfFilter extends OncePerRequestFilter { private final Log logger = LogFactory.getLog(getClass()); - private final CsrfTokenRequestHandler requestHandler; + private final CsrfTokenRepository tokenRepository; private RequestMatcher requireCsrfProtectionMatcher = DEFAULT_CSRF_MATCHER; private AccessDeniedHandler accessDeniedHandler = new AccessDeniedHandlerImpl(); - /** - * Creates a new instance. - * @param csrfTokenRepository the {@link CsrfTokenRepository} to use - * @deprecated Use {@link CsrfFilter#CsrfFilter(CsrfTokenRequestHandler)} instead - */ - @Deprecated - public CsrfFilter(CsrfTokenRepository csrfTokenRepository) { - this(new CsrfTokenRepositoryRequestHandler(csrfTokenRepository)); - } + private CsrfTokenRequestHandler requestHandler = new CsrfTokenRequestAttributeHandler(); /** * Creates a new instance. - * @param requestHandler the {@link CsrfTokenRequestHandler} to use. Default is - * {@link CsrfTokenRepositoryRequestHandler}. + * @param tokenRepository the {@link CsrfTokenRepository} to use */ - public CsrfFilter(CsrfTokenRequestHandler requestHandler) { - Assert.notNull(requestHandler, "requestHandler cannot be null"); - this.requestHandler = requestHandler; + public CsrfFilter(CsrfTokenRepository tokenRepository) { + Assert.notNull(tokenRepository, "tokenRepository cannot be null"); + this.tokenRepository = tokenRepository; } @Override @@ -116,7 +107,8 @@ public final class CsrfFilter extends OncePerRequestFilter { @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { - DeferredCsrfToken deferredCsrfToken = this.requestHandler.handle(request, response); + DeferredCsrfToken deferredCsrfToken = this.tokenRepository.loadDeferredToken(request, response); + this.requestHandler.handle(request, response, deferredCsrfToken::get); if (!this.requireCsrfProtectionMatcher.matches(request)) { if (this.logger.isTraceEnabled()) { this.logger.trace("Did not protect against CSRF since request did not match " @@ -174,6 +166,21 @@ public final class CsrfFilter extends OncePerRequestFilter { this.accessDeniedHandler = accessDeniedHandler; } + /** + * Specifies a {@link CsrfTokenRequestHandler} that is used to make the + * {@link CsrfToken} available as a request attribute. + * + *

+ * The default is {@link CsrfTokenRequestAttributeHandler}. + *

+ * @param requestHandler the {@link CsrfTokenRequestHandler} to use + * @since 5.8 + */ + public void setRequestHandler(CsrfTokenRequestHandler requestHandler) { + Assert.notNull(requestHandler, "requestHandler cannot be null"); + this.requestHandler = requestHandler; + } + /** * Constant time comparison to prevent against timing attacks. * @param expected diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRepository.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRepository.java index bc20131685..792abf10a3 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRepository.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRepository.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ import javax.servlet.http.HttpSession; * {@link HttpSession}. * * @author Rob Winch + * @author Steve Riesenberg * @since 3.2 * @see HttpSessionCsrfTokenRepository */ @@ -55,4 +56,20 @@ public interface CsrfTokenRepository { */ CsrfToken loadToken(HttpServletRequest request); + /** + * Defers loading the {@link CsrfToken} using the {@link HttpServletRequest} and + * {@link HttpServletResponse} until it is needed by the application. + *

+ * The returned {@link DeferredCsrfToken} is cached to allow subsequent calls to + * {@link DeferredCsrfToken#get()} to return the same {@link CsrfToken} without the + * cost of loading or generating the token again. + * @param request the {@link HttpServletRequest} to use + * @param response the {@link HttpServletResponse} to use + * @return a {@link DeferredCsrfToken} that will load the {@link CsrfToken} + * @since 5.8 + */ + default DeferredCsrfToken loadDeferredToken(HttpServletRequest request, HttpServletResponse response) { + return new RepositoryDeferredCsrfToken(this, request, response); + } + } diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRepositoryRequestHandler.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestAttributeHandler.java similarity index 59% rename from web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRepositoryRequestHandler.java rename to web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestAttributeHandler.java index eef2d8976c..f897a5659c 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRepositoryRequestHandler.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestAttributeHandler.java @@ -31,29 +31,10 @@ import org.springframework.util.Assert; * @author Steve Riesenberg * @since 5.8 */ -public class CsrfTokenRepositoryRequestHandler implements CsrfTokenRequestHandler { - - private final CsrfTokenRepository csrfTokenRepository; +public class CsrfTokenRequestAttributeHandler implements CsrfTokenRequestHandler { private String csrfRequestAttributeName; - /** - * Creates a new instance. - */ - public CsrfTokenRepositoryRequestHandler() { - this(new HttpSessionCsrfTokenRepository()); - } - - /** - * Creates a new instance. - * @param csrfTokenRepository the {@link CsrfTokenRepository} to use. Default - * {@link HttpSessionCsrfTokenRepository} - */ - public CsrfTokenRepositoryRequestHandler(CsrfTokenRepository csrfTokenRepository) { - Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null"); - this.csrfTokenRepository = csrfTokenRepository; - } - /** * The {@link CsrfToken} is available as a request attribute named * {@code CsrfToken.class.getName()}. By default, an additional request attribute that @@ -67,18 +48,18 @@ public class CsrfTokenRepositoryRequestHandler implements CsrfTokenRequestHandle } @Override - public DeferredCsrfToken handle(HttpServletRequest request, HttpServletResponse response) { + public void handle(HttpServletRequest request, HttpServletResponse response, + Supplier deferredCsrfToken) { Assert.notNull(request, "request cannot be null"); Assert.notNull(response, "response cannot be null"); + Assert.notNull(deferredCsrfToken, "deferredCsrfToken cannot be null"); request.setAttribute(HttpServletResponse.class.getName(), response); - DeferredCsrfToken deferredCsrfToken = new RepositoryDeferredCsrfToken(request, response); - CsrfToken csrfToken = new SupplierCsrfToken(deferredCsrfToken::get); + CsrfToken csrfToken = new SupplierCsrfToken(deferredCsrfToken); request.setAttribute(CsrfToken.class.getName(), csrfToken); String csrfAttrName = (this.csrfRequestAttributeName != null) ? this.csrfRequestAttributeName : csrfToken.getParameterName(); request.setAttribute(csrfAttrName, csrfToken); - return deferredCsrfToken; } private static final class SupplierCsrfToken implements CsrfToken { @@ -114,46 +95,4 @@ public class CsrfTokenRepositoryRequestHandler implements CsrfTokenRequestHandle } - private final class RepositoryDeferredCsrfToken implements DeferredCsrfToken { - - private final HttpServletRequest request; - - private final HttpServletResponse response; - - private CsrfToken csrfToken; - - private Boolean missingToken; - - RepositoryDeferredCsrfToken(HttpServletRequest request, HttpServletResponse response) { - this.request = request; - this.response = response; - } - - @Override - public CsrfToken get() { - init(); - return this.csrfToken; - } - - @Override - public boolean isGenerated() { - init(); - return this.missingToken; - } - - private void init() { - if (this.csrfToken != null) { - return; - } - this.csrfToken = CsrfTokenRepositoryRequestHandler.this.csrfTokenRepository.loadToken(this.request); - this.missingToken = (this.csrfToken == null); - if (this.missingToken) { - this.csrfToken = CsrfTokenRepositoryRequestHandler.this.csrfTokenRepository.generateToken(this.request); - CsrfTokenRepositoryRequestHandler.this.csrfTokenRepository.saveToken(this.csrfToken, this.request, - this.response); - } - } - - } - } diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestHandler.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestHandler.java index 495756dbe7..aa30162bcd 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestHandler.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestHandler.java @@ -16,20 +16,22 @@ package org.springframework.security.web.csrf; +import java.util.function.Supplier; + import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.springframework.util.Assert; /** - * An interface that is used to determine the {@link CsrfToken} to use and make the - * {@link CsrfToken} available as a request attribute. Implementations of this interface - * may choose to perform additional tasks or customize how the token is made available to - * the application through request attributes. + * A callback interface that is used to make the {@link CsrfToken} created by the + * {@link CsrfTokenRepository} available as a request attribute. Implementations of this + * interface may choose to perform additional tasks or customize how the token is made + * available to the application through request attributes. * * @author Steve Riesenberg * @since 5.8 - * @see CsrfTokenRepositoryRequestHandler + * @see CsrfTokenRequestAttributeHandler */ @FunctionalInterface public interface CsrfTokenRequestHandler extends CsrfTokenRequestResolver { @@ -38,8 +40,9 @@ public interface CsrfTokenRequestHandler extends CsrfTokenRequestResolver { * Handles a request using a {@link CsrfToken}. * @param request the {@code HttpServletRequest} being handled * @param response the {@code HttpServletResponse} being handled + * @param csrfToken the {@link CsrfToken} created by the {@link CsrfTokenRepository} */ - DeferredCsrfToken handle(HttpServletRequest request, HttpServletResponse response); + void handle(HttpServletRequest request, HttpServletResponse response, Supplier csrfToken); @Override default String resolveCsrfTokenValue(HttpServletRequest request, CsrfToken csrfToken) { diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestResolver.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestResolver.java index e19ff3fbbb..d9fb93e50e 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestResolver.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestResolver.java @@ -25,7 +25,7 @@ import javax.servlet.http.HttpServletRequest; * * @author Steve Riesenberg * @since 5.8 - * @see CsrfTokenRepositoryRequestHandler + * @see CsrfTokenRequestAttributeHandler */ @FunctionalInterface public interface CsrfTokenRequestResolver { diff --git a/web/src/main/java/org/springframework/security/web/csrf/DeferredCsrfToken.java b/web/src/main/java/org/springframework/security/web/csrf/DeferredCsrfToken.java index d8ab774570..a27a31f7c1 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/DeferredCsrfToken.java +++ b/web/src/main/java/org/springframework/security/web/csrf/DeferredCsrfToken.java @@ -20,11 +20,12 @@ package org.springframework.security.web.csrf; * An interface that allows delayed access to a {@link CsrfToken} that may be generated. * * @author Rob Winch + * @author Steve Riesenberg * @since 5.8 */ public interface DeferredCsrfToken { - /*** + /** * Gets the {@link CsrfToken} * @return a non-null {@link CsrfToken} */ diff --git a/web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java b/web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java index 066723c189..692e002e62 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java +++ b/web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java @@ -27,8 +27,9 @@ import org.springframework.util.Assert; * * @author Rob Winch * @since 4.1 - * @deprecated Use org.springframework.security.web.csrf.CsrfTokenRequestHandler which - * returns a {@link DeferredCsrfToken} + * @deprecated Use + * {@link CsrfTokenRepository#loadDeferredToken(HttpServletRequest, HttpServletResponse)} + * which returns a {@link DeferredCsrfToken} */ @Deprecated public final class LazyCsrfTokenRepository implements CsrfTokenRepository { diff --git a/web/src/main/java/org/springframework/security/web/csrf/RepositoryDeferredCsrfToken.java b/web/src/main/java/org/springframework/security/web/csrf/RepositoryDeferredCsrfToken.java new file mode 100644 index 0000000000..2f885a216a --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/csrf/RepositoryDeferredCsrfToken.java @@ -0,0 +1,71 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web.csrf; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +/** + * @author Rob Winch + * @author Steve Riesenberg + * @since 5.8 + */ +final class RepositoryDeferredCsrfToken implements DeferredCsrfToken { + + private final CsrfTokenRepository csrfTokenRepository; + + private final HttpServletRequest request; + + private final HttpServletResponse response; + + private CsrfToken csrfToken; + + private boolean missingToken; + + RepositoryDeferredCsrfToken(CsrfTokenRepository csrfTokenRepository, HttpServletRequest request, + HttpServletResponse response) { + this.csrfTokenRepository = csrfTokenRepository; + this.request = request; + this.response = response; + } + + @Override + public CsrfToken get() { + init(); + return this.csrfToken; + } + + @Override + public boolean isGenerated() { + init(); + return this.missingToken; + } + + private void init() { + if (this.csrfToken != null) { + return; + } + + this.csrfToken = this.csrfTokenRepository.loadToken(this.request); + this.missingToken = (this.csrfToken == null); + if (this.missingToken) { + this.csrfToken = this.csrfTokenRepository.generateToken(this.request); + this.csrfTokenRepository.saveToken(this.csrfToken, this.request, this.response); + } + } + +} diff --git a/web/src/test/java/org/springframework/security/web/csrf/CookieCsrfTokenRepositoryTests.java b/web/src/test/java/org/springframework/security/web/csrf/CookieCsrfTokenRepositoryTests.java index 3b08d4318f..093357699f 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/CookieCsrfTokenRepositoryTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/CookieCsrfTokenRepositoryTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ import org.springframework.mock.web.MockHttpServletResponse; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken; /** * @author Rob Winch @@ -246,6 +247,33 @@ public class CookieCsrfTokenRepositoryTests { assertThat(loadToken.getToken()).isEqualTo(value); } + @Test + public void loadDeferredTokenWhenDoesNotExistThenGeneratedAndSaved() { + DeferredCsrfToken deferredCsrfToken = this.repository.loadDeferredToken(this.request, this.response); + CsrfToken csrfToken = deferredCsrfToken.get(); + assertThat(csrfToken).isNotNull(); + assertThat(deferredCsrfToken.isGenerated()).isTrue(); + Cookie tokenCookie = this.response.getCookie(CookieCsrfTokenRepository.DEFAULT_CSRF_COOKIE_NAME); + assertThat(tokenCookie).isNotNull(); + assertThat(tokenCookie.getMaxAge()).isEqualTo(-1); + assertThat(tokenCookie.getName()).isEqualTo(CookieCsrfTokenRepository.DEFAULT_CSRF_COOKIE_NAME); + assertThat(tokenCookie.getPath()).isEqualTo(this.request.getContextPath()); + assertThat(tokenCookie.getSecure()).isEqualTo(this.request.isSecure()); + assertThat(tokenCookie.getValue()).isEqualTo(csrfToken.getToken()); + assertThat(tokenCookie.isHttpOnly()).isEqualTo(true); + } + + @Test + public void loadDeferredTokenWhenExistsThenLoaded() { + CsrfToken generatedToken = this.repository.generateToken(this.request); + this.request + .setCookies(new Cookie(CookieCsrfTokenRepository.DEFAULT_CSRF_COOKIE_NAME, generatedToken.getToken())); + DeferredCsrfToken deferredCsrfToken = this.repository.loadDeferredToken(this.request, this.response); + CsrfToken csrfToken = deferredCsrfToken.get(); + assertThatCsrfToken(csrfToken).isEqualTo(generatedToken); + assertThat(deferredCsrfToken.isGenerated()).isFalse(); + } + @Test public void setCookieNameNullIllegalArgumentException() { assertThatIllegalArgumentException().isThrownBy(() -> this.repository.setCookieName(null)); diff --git a/web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java b/web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java index baa7e40b01..5342cae902 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -82,23 +82,25 @@ public class CsrfAuthenticationStrategyTests { @Test public void onAuthenticationWhenCustomRequestHandlerThenUsed() { + given(this.csrfTokenRepository.loadDeferredToken(this.request, this.response)) + .willReturn(new TestDeferredCsrfToken(this.existingToken, false)); + CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class); this.strategy.setRequestHandler(requestHandler); this.strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"), this.request, this.response); - verify(requestHandler).handle(eq(this.request), eq(this.response)); + verify(requestHandler).handle(eq(this.request), eq(this.response), any()); verifyNoMoreInteractions(requestHandler); } @Test - public void logoutRemovesCsrfTokenAndSavesNew() { - given(this.csrfTokenRepository.loadToken(this.request)).willReturn(null, this.existingToken); - given(this.csrfTokenRepository.generateToken(this.request)).willReturn(this.generatedToken); + public void logoutRemovesCsrfTokenAndLoadsNewDeferredCsrfToken() { + given(this.csrfTokenRepository.loadDeferredToken(this.request, this.response)) + .willReturn(new TestDeferredCsrfToken(this.generatedToken, false)); this.strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"), this.request, this.response); verify(this.csrfTokenRepository).saveToken(null, this.request, this.response); - verify(this.csrfTokenRepository).saveToken(eq(this.generatedToken), any(HttpServletRequest.class), - any(HttpServletResponse.class)); + verify(this.csrfTokenRepository).loadDeferredToken(this.request, this.response); // SEC-2404, SEC-2832 CsrfToken tokenInRequest = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName()); assertThat(tokenInRequest.getToken()).isSameAs(this.generatedToken.getToken()); @@ -119,17 +121,10 @@ public class CsrfAuthenticationStrategyTests { any(HttpServletResponse.class)); CsrfToken tokenInRequest = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName()); tokenInRequest.getToken(); + verify(this.csrfTokenRepository).loadToken(this.request); + verify(this.csrfTokenRepository).generateToken(this.request); verify(this.csrfTokenRepository).saveToken(eq(this.generatedToken), any(HttpServletRequest.class), any(HttpServletResponse.class)); } - @Test - public void logoutWhenNoCsrfToken() { - given(this.csrfTokenRepository.generateToken(this.request)).willReturn(this.generatedToken); - this.strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"), this.request, - this.response); - verify(this.csrfTokenRepository).saveToken(any(CsrfToken.class), any(HttpServletRequest.class), - any(HttpServletResponse.class)); - } - } diff --git a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java index 81b51cf191..8b0fc1449a 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java @@ -44,7 +44,6 @@ import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; -import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; @@ -86,11 +85,7 @@ public class CsrfFilterTests { } private CsrfFilter createCsrfFilter(CsrfTokenRepository repository) { - return createCsrfFilter(new CsrfTokenRepositoryRequestHandler(repository)); - } - - private CsrfFilter createCsrfFilter(CsrfTokenRequestHandler requestHandler) { - CsrfFilter filter = new CsrfFilter(requestHandler); + CsrfFilter filter = new CsrfFilter(repository); filter.setRequireCsrfProtectionMatcher(this.requestMatcher); filter.setAccessDeniedHandler(this.deniedHandler); return filter; @@ -103,7 +98,7 @@ public class CsrfFilterTests { @Test public void constructorNullRepository() { - assertThatIllegalArgumentException().isThrownBy(() -> new CsrfFilter((CsrfTokenRequestHandler) null)); + assertThatIllegalArgumentException().isThrownBy(() -> new CsrfFilter(null)); } // SEC-2276 @@ -128,7 +123,8 @@ public class CsrfFilterTests { @Test public void doFilterAccessDeniedNoTokenPresent() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)) + .willReturn(new TestDeferredCsrfToken(this.token, false)); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); @@ -139,7 +135,8 @@ public class CsrfFilterTests { @Test public void doFilterAccessDeniedIncorrectTokenPresent() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)) + .willReturn(new TestDeferredCsrfToken(this.token, false)); this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID"); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); @@ -151,7 +148,8 @@ public class CsrfFilterTests { @Test public void doFilterAccessDeniedIncorrectTokenPresentHeader() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)) + .willReturn(new TestDeferredCsrfToken(this.token, false)); this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID"); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); @@ -164,7 +162,8 @@ public class CsrfFilterTests { public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)) + .willReturn(new TestDeferredCsrfToken(this.token, false)); this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID"); this.filter.doFilter(this.request, this.response, this.filterChain); @@ -177,7 +176,8 @@ public class CsrfFilterTests { @Test public void doFilterNotCsrfRequestExistingToken() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(false); - given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)) + .willReturn(new TestDeferredCsrfToken(this.token, false)); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); @@ -188,7 +188,8 @@ public class CsrfFilterTests { @Test public void doFilterNotCsrfRequestGenerateToken() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(false); - given(this.tokenRepository.generateToken(this.request)).willReturn(this.token); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)) + .willReturn(new TestDeferredCsrfToken(this.token, true)); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); @@ -199,7 +200,8 @@ public class CsrfFilterTests { @Test public void doFilterIsCsrfRequestExistingTokenHeader() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)) + .willReturn(new TestDeferredCsrfToken(this.token, false)); this.request.addHeader(this.token.getHeaderName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); @@ -212,7 +214,8 @@ public class CsrfFilterTests { public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)) + .willReturn(new TestDeferredCsrfToken(this.token, false)); this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID"); this.request.addHeader(this.token.getHeaderName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); @@ -225,7 +228,8 @@ public class CsrfFilterTests { @Test public void doFilterIsCsrfRequestExistingToken() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)) + .willReturn(new TestDeferredCsrfToken(this.token, false)); this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); @@ -239,7 +243,8 @@ public class CsrfFilterTests { @Test public void doFilterIsCsrfRequestGenerateToken() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.generateToken(this.request)).willReturn(this.token); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)) + .willReturn(new TestDeferredCsrfToken(this.token, true)); this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); @@ -247,17 +252,17 @@ public class CsrfFilterTests { // LazyCsrfTokenRepository requires the response as an attribute assertThat(this.request.getAttribute(HttpServletResponse.class.getName())).isEqualTo(this.response); verify(this.filterChain).doFilter(this.request, this.response); - verify(this.tokenRepository).saveToken(this.token, this.request, this.response); verifyNoMoreInteractions(this.deniedHandler); } @Test public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods() throws ServletException, IOException { - this.filter = createCsrfFilter(this.tokenRepository); + this.filter = new CsrfFilter(this.tokenRepository); this.filter.setAccessDeniedHandler(this.deniedHandler); for (String method : Arrays.asList("GET", "TRACE", "OPTIONS", "HEAD")) { resetRequestResponse(); - given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)) + .willReturn(new TestDeferredCsrfToken(this.token, false)); this.request.setMethod(method); this.filter.doFilter(this.request, this.response, this.filterChain); verify(this.filterChain).doFilter(this.request, this.response); @@ -273,11 +278,12 @@ public class CsrfFilterTests { */ @Test public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethodsCaseSensitive() throws Exception { - this.filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.tokenRepository)); + this.filter = new CsrfFilter(this.tokenRepository); this.filter.setAccessDeniedHandler(this.deniedHandler); for (String method : Arrays.asList("get", "TrAcE", "oPTIOnS", "hEaD")) { resetRequestResponse(); - given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)) + .willReturn(new TestDeferredCsrfToken(this.token, false)); this.request.setMethod(method); this.filter.doFilter(this.request, this.response, this.filterChain); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), @@ -288,11 +294,12 @@ public class CsrfFilterTests { @Test public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods() throws ServletException, IOException { - this.filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.tokenRepository)); + this.filter = new CsrfFilter(this.tokenRepository); this.filter.setAccessDeniedHandler(this.deniedHandler); for (String method : Arrays.asList("POST", "PUT", "PATCH", "DELETE", "INVALID")) { resetRequestResponse(); - given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)) + .willReturn(new TestDeferredCsrfToken(this.token, false)); this.request.setMethod(method); this.filter.doFilter(this.request, this.response, this.filterChain); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), @@ -303,10 +310,11 @@ public class CsrfFilterTests { @Test public void doFilterDefaultAccessDenied() throws ServletException, IOException { - this.filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.tokenRepository)); + this.filter = new CsrfFilter(this.tokenRepository); this.filter.setRequireCsrfProtectionMatcher(this.requestMatcher); given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)) + .willReturn(new TestDeferredCsrfToken(this.token, false)); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); @@ -317,7 +325,7 @@ public class CsrfFilterTests { @Test public void doFilterWhenSkipRequestInvokedThenSkips() throws Exception { CsrfTokenRepository repository = mock(CsrfTokenRepository.class); - CsrfFilter filter = createCsrfFilter(repository); + CsrfFilter filter = new CsrfFilter(repository); lenient().when(repository.loadToken(any(HttpServletRequest.class))).thenReturn(this.token); MockHttpServletRequest request = new MockHttpServletRequest(); CsrfFilter.skipRequest(request); @@ -333,7 +341,8 @@ public class CsrfFilterTests { given(token.getToken()).willReturn(null); given(token.getHeaderName()).willReturn(this.token.getHeaderName()); given(token.getParameterName()).willReturn(this.token.getParameterName()); - given(this.tokenRepository.loadToken(this.request)).willReturn(token); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)) + .willReturn(new TestDeferredCsrfToken(token, false)); given(this.requestMatcher.matches(this.request)).willReturn(true); filter.doFilterInternal(this.request, this.response, this.filterChain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); @@ -341,13 +350,15 @@ public class CsrfFilterTests { @Test public void doFilterWhenRequestHandlerThenUsed() throws Exception { - CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class); - given(requestHandler.handle(this.request, this.response)) + given(this.tokenRepository.loadDeferredToken(this.request, this.response)) .willReturn(new TestDeferredCsrfToken(this.token, false)); - this.filter = createCsrfFilter(requestHandler); + CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class); + this.filter = createCsrfFilter(this.tokenRepository); + this.filter.setRequestHandler(requestHandler); this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); - verify(requestHandler).handle(eq(this.request), eq(this.response)); + verify(this.tokenRepository).loadDeferredToken(this.request, this.response); + verify(requestHandler).handle(eq(this.request), eq(this.response), any()); verify(this.filterChain).doFilter(this.request, this.response); } @@ -365,41 +376,20 @@ public class CsrfFilterTests { @Test public void doFilterWhenCsrfRequestAttributeNameThenNoCsrfTokenMethodInvokedOnGet() throws ServletException, IOException { + CsrfFilter filter = createCsrfFilter(this.tokenRepository); String csrfAttrName = "_csrf"; - CsrfTokenRepositoryRequestHandler requestHandler = new CsrfTokenRepositoryRequestHandler(this.tokenRepository); + CsrfTokenRequestAttributeHandler requestHandler = new CsrfTokenRequestAttributeHandler(); requestHandler.setCsrfRequestAttributeName(csrfAttrName); - this.filter = createCsrfFilter(requestHandler); - CsrfToken expectedCsrfToken = spy(this.token); - given(this.tokenRepository.loadToken(this.request)).willReturn(expectedCsrfToken); + filter.setRequestHandler(requestHandler); + CsrfToken expectedCsrfToken = mock(CsrfToken.class); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)) + .willReturn(new TestDeferredCsrfToken(expectedCsrfToken, true)); - this.filter.doFilter(this.request, this.response, this.filterChain); + filter.doFilter(this.request, this.response, this.filterChain); verifyNoInteractions(expectedCsrfToken); CsrfToken tokenFromRequest = (CsrfToken) this.request.getAttribute(csrfAttrName); assertThatCsrfToken(tokenFromRequest).isEqualTo(expectedCsrfToken); } - private static final class TestDeferredCsrfToken implements DeferredCsrfToken { - - private final CsrfToken csrfToken; - - private final boolean isGenerated; - - private TestDeferredCsrfToken(CsrfToken csrfToken, boolean isGenerated) { - this.csrfToken = csrfToken; - this.isGenerated = isGenerated; - } - - @Override - public CsrfToken get() { - return this.csrfToken; - } - - @Override - public boolean isGenerated() { - return this.isGenerated; - } - - } - } diff --git a/web/src/test/java/org/springframework/security/web/csrf/CsrfTokenRepositoryRequestHandlerTests.java b/web/src/test/java/org/springframework/security/web/csrf/CsrfTokenRequestAttributeHandlerTests.java similarity index 72% rename from web/src/test/java/org/springframework/security/web/csrf/CsrfTokenRepositoryRequestHandlerTests.java rename to web/src/test/java/org/springframework/security/web/csrf/CsrfTokenRequestAttributeHandlerTests.java index 22d01f7fdb..8e6db8eebb 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/CsrfTokenRepositoryRequestHandlerTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/CsrfTokenRequestAttributeHandlerTests.java @@ -18,29 +18,22 @@ package org.springframework.security.web.csrf; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.Mock; -import org.mockito.junit.jupiter.MockitoExtension; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; -import static org.mockito.BDDMockito.given; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken; /** - * Tests for {@link CsrfTokenRepositoryRequestHandler}. + * Tests for {@link CsrfTokenRequestAttributeHandler}. * * @author Steve Riesenberg * @since 5.8 */ -@ExtendWith(MockitoExtension.class) -public class CsrfTokenRepositoryRequestHandlerTests { - - @Mock - CsrfTokenRepository tokenRepository; +public class CsrfTokenRequestAttributeHandlerTests { private MockHttpServletRequest request; @@ -48,76 +41,73 @@ public class CsrfTokenRepositoryRequestHandlerTests { private CsrfToken token; - private CsrfTokenRepositoryRequestHandler handler; + private CsrfTokenRequestAttributeHandler handler; @BeforeEach public void setup() { this.request = new MockHttpServletRequest(); this.response = new MockHttpServletResponse(); this.token = new DefaultCsrfToken("headerName", "paramName", "csrfTokenValue"); - this.handler = new CsrfTokenRepositoryRequestHandler(this.tokenRepository); - } - - @Test - public void constructorWhenCsrfTokenRepositoryIsNullThenThrowsIllegalArgumentException() { - // @formatter:off - assertThatIllegalArgumentException() - .isThrownBy(() -> new CsrfTokenRepositoryRequestHandler(null)) - .withMessage("csrfTokenRepository cannot be null"); - // @formatter:on + this.handler = new CsrfTokenRequestAttributeHandler(); } @Test public void handleWhenRequestIsNullThenThrowsIllegalArgumentException() { - // @formatter:off assertThatIllegalArgumentException() - .isThrownBy(() -> this.handler.handle(null, this.response)) + .isThrownBy(() -> this.handler.handle(null, this.response, () -> this.token)) .withMessage("request cannot be null"); - // @formatter:on } @Test public void handleWhenResponseIsNullThenThrowsIllegalArgumentException() { // @formatter:off assertThatIllegalArgumentException() - .isThrownBy(() -> this.handler.handle(this.request, null)) + .isThrownBy(() -> this.handler.handle(this.request, null, () -> this.token)) .withMessage("response cannot be null"); // @formatter:on } + @Test + public void handleWhenCsrfTokenSupplierIsNullThenThrowsIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.handler.handle(this.request, this.response, null)) + .withMessage("deferredCsrfToken cannot be null"); + } + + @Test + public void handleWhenCsrfTokenIsNullThenThrowsIllegalArgumentException() { + // @formatter:off + this.handler.setCsrfRequestAttributeName(null); + assertThatIllegalStateException() + .isThrownBy(() -> this.handler.handle(this.request, this.response, () -> null)) + .withMessage("csrfTokenSupplier returned null delegate"); + // @formatter:on + } + @Test public void handleWhenCsrfRequestAttributeSetThenUsed() { - given(this.tokenRepository.generateToken(this.request)).willReturn(this.token); this.handler.setCsrfRequestAttributeName("_csrf"); - this.handler.handle(this.request, this.response); + this.handler.handle(this.request, this.response, () -> this.token); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute("_csrf")).isEqualTo(this.token); } @Test public void handleWhenValidParametersThenRequestAttributesSet() { - given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); - this.handler.handle(this.request, this.response); + this.handler.handle(this.request, this.response, () -> this.token); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); } @Test public void resolveCsrfTokenValueWhenRequestIsNullThenThrowsIllegalArgumentException() { - // @formatter:off - assertThatIllegalArgumentException() - .isThrownBy(() -> this.handler.resolveCsrfTokenValue(null, this.token)) + assertThatIllegalArgumentException().isThrownBy(() -> this.handler.resolveCsrfTokenValue(null, this.token)) .withMessage("request cannot be null"); - // @formatter:on } @Test public void resolveCsrfTokenValueWhenCsrfTokenIsNullThenThrowsIllegalArgumentException() { - // @formatter:off - assertThatIllegalArgumentException() - .isThrownBy(() -> this.handler.resolveCsrfTokenValue(this.request, null)) + assertThatIllegalArgumentException().isThrownBy(() -> this.handler.resolveCsrfTokenValue(this.request, null)) .withMessage("csrfToken cannot be null"); - // @formatter:on } @Test diff --git a/web/src/test/java/org/springframework/security/web/csrf/HttpSessionCsrfTokenRepositoryTests.java b/web/src/test/java/org/springframework/security/web/csrf/HttpSessionCsrfTokenRepositoryTests.java index 13d9fca65d..3c7c33bdba 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/HttpSessionCsrfTokenRepositoryTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/HttpSessionCsrfTokenRepositoryTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ import org.springframework.mock.web.MockHttpServletResponse; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken; /** * @author Rob Winch @@ -85,6 +86,26 @@ public class HttpSessionCsrfTokenRepositoryTests { assertThat(this.repo.loadToken(this.request)).isNull(); } + @Test + public void loadDeferredTokenWhenDoesNotExistThenGeneratedAndSaved() { + DeferredCsrfToken deferredCsrfToken = this.repo.loadDeferredToken(this.request, this.response); + CsrfToken csrfToken = deferredCsrfToken.get(); + assertThat(csrfToken).isNotNull(); + assertThat(deferredCsrfToken.isGenerated()).isTrue(); + String attrName = this.request.getSession().getAttributeNames().nextElement(); + assertThatCsrfToken(this.request.getSession().getAttribute(attrName)).isEqualTo(csrfToken); + } + + @Test + public void loadDeferredTokenWhenExistsThenLoaded() { + CsrfToken tokenToSave = new DefaultCsrfToken("123", "abc", "def"); + this.repo.saveToken(tokenToSave, this.request, this.response); + DeferredCsrfToken deferredCsrfToken = this.repo.loadDeferredToken(this.request, this.response); + CsrfToken csrfToken = deferredCsrfToken.get(); + assertThatCsrfToken(csrfToken).isEqualTo(tokenToSave); + assertThat(deferredCsrfToken.isGenerated()).isFalse(); + } + @Test public void saveToken() { CsrfToken tokenToSave = new DefaultCsrfToken("123", "abc", "def"); diff --git a/web/src/test/java/org/springframework/security/web/csrf/TestDeferredCsrfToken.java b/web/src/test/java/org/springframework/security/web/csrf/TestDeferredCsrfToken.java new file mode 100644 index 0000000000..193b476d2b --- /dev/null +++ b/web/src/test/java/org/springframework/security/web/csrf/TestDeferredCsrfToken.java @@ -0,0 +1,40 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web.csrf; + +final class TestDeferredCsrfToken implements DeferredCsrfToken { + + private final CsrfToken csrfToken; + + private final boolean isGenerated; + + TestDeferredCsrfToken(CsrfToken csrfToken, boolean isGenerated) { + this.csrfToken = csrfToken; + this.isGenerated = isGenerated; + } + + @Override + public CsrfToken get() { + return this.csrfToken; + } + + @Override + public boolean isGenerated() { + return this.isGenerated; + } + +}