From a93fb1e0e7b81a8bc2a09e92f0312b66bcc3c386 Mon Sep 17 00:00:00 2001 From: Rob Winch Date: Tue, 9 Aug 2016 15:00:06 -0500 Subject: [PATCH] Fix csrf() when used then not used Previously if csrf() was used and subsequently not used, the TestCsrfTokenRepository was still used. This makes it difficult to test the actual CsrfTokenRepository implementation. Now the TestCsrfTokenRepository is only used if explicitly enabled. Fixes gh-4016 --- .../SecurityMockMvcRequestPostProcessors.java | 29 ++++++++++++++++--- ...yMockMvcRequestBuildersFormLoginTests.java | 4 +-- ...MockMvcRequestBuildersFormLogoutTests.java | 4 +-- ...MockMvcRequestPostProcessorsCsrfTests.java | 23 +++++++++++++++ 4 files changed, 52 insertions(+), 8 deletions(-) diff --git a/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java b/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java index 86bcef4779..06e138f4db 100644 --- a/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java +++ b/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java @@ -318,13 +318,13 @@ public final class SecurityMockMvcRequestPostProcessors { */ @Override public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { - CsrfTokenRepository repository = WebTestUtils.getCsrfTokenRepository(request); if (!(repository instanceof TestCsrfTokenRepository)) { repository = new TestCsrfTokenRepository( new HttpSessionCsrfTokenRepository()); WebTestUtils.setCsrfTokenRepository(request, repository); } + TestCsrfTokenRepository.enable(request); CsrfToken token = repository.generateToken(request); repository.saveToken(token, request, new MockHttpServletResponse()); String tokenValue = this.useInvalidToken ? "invalid" + token.getToken() @@ -367,9 +367,12 @@ public final class SecurityMockMvcRequestPostProcessors { * request is wrapped (i.e. Spring Session is in use). */ static class TestCsrfTokenRepository implements CsrfTokenRepository { - final static String ATTR_NAME = TestCsrfTokenRepository.class.getName() + final static String TOKEN_ATTR_NAME = TestCsrfTokenRepository.class.getName() .concat(".TOKEN"); + final static String ENABLED_ATTR_NAME = TestCsrfTokenRepository.class + .getName().concat(".ENABLED"); + private final CsrfTokenRepository delegate; private TestCsrfTokenRepository(CsrfTokenRepository delegate) { @@ -384,12 +387,30 @@ public final class SecurityMockMvcRequestPostProcessors { @Override public void saveToken(CsrfToken token, HttpServletRequest request, HttpServletResponse response) { - request.setAttribute(ATTR_NAME, token); + if (isEnabled(request)) { + request.setAttribute(TOKEN_ATTR_NAME, token); + } + else { + this.delegate.saveToken(token, request, response); + } } @Override public CsrfToken loadToken(HttpServletRequest request) { - return (CsrfToken) request.getAttribute(ATTR_NAME); + if (isEnabled(request)) { + return (CsrfToken) request.getAttribute(TOKEN_ATTR_NAME); + } + else { + return this.delegate.loadToken(request); + } + } + + public static void enable(HttpServletRequest request) { + request.setAttribute(ENABLED_ATTR_NAME, Boolean.TRUE); + } + + public boolean isEnabled(HttpServletRequest request) { + return Boolean.TRUE.equals(request.getAttribute(ENABLED_ATTR_NAME)); } } } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java index bddfa3d218..24e3a58316 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java @@ -39,7 +39,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests { public void defaults() throws Exception { MockHttpServletRequest request = formLogin().buildRequest(this.servletContext); CsrfToken token = (CsrfToken) request - .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME); + .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); assertThat(request.getParameter("username")).isEqualTo("user"); assertThat(request.getParameter("password")).isEqualTo("password"); @@ -56,7 +56,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests { .password("password", "secret").buildRequest(this.servletContext); CsrfToken token = (CsrfToken) request - .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME); + .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); assertThat(request.getParameter("username")).isEqualTo("admin"); assertThat(request.getParameter("password")).isEqualTo("secret"); diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java index 16b7ebec06..4d5314bb14 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java @@ -37,7 +37,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests { public void defaults() throws Exception { MockHttpServletRequest request = logout().buildRequest(servletContext); - CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME); + CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getParameter(token.getParameterName())).isEqualTo( @@ -50,7 +50,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests { MockHttpServletRequest request = logout("/admin/logout").buildRequest( servletContext); - CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME); + CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getParameter(token.getParameterName())).isEqualTo( diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfTests.java index a405c6ddbb..f79d73b90d 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfTests.java @@ -31,18 +31,22 @@ import org.junit.runner.RunWith; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpSession; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessorsCsrfTests.Config.TheController; import org.springframework.security.web.FilterChainProxy; +import org.springframework.security.web.csrf.CsrfToken; +import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; import org.springframework.test.context.web.WebAppConfiguration; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; import org.springframework.test.web.servlet.ResultMatcher; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import org.springframework.test.web.servlet.setup.MockMvcBuilders; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; @@ -143,6 +147,25 @@ public class SecurityMockMvcRequestPostProcessorsCsrfTests { // @formatter:on } + // gh-4016 + @Test + public void csrfWhenUsedThenDoesNotImpactOriginalRepository() throws Exception { + // @formatter:off + this.mockMvc.perform(post("/").with(csrf())); + + MockHttpServletRequest request = new MockHttpServletRequest(); + HttpSessionCsrfTokenRepository repo = new HttpSessionCsrfTokenRepository(); + CsrfToken token = repo.generateToken(request); + repo.saveToken(token, request, new MockHttpServletResponse()); + + MockHttpServletRequestBuilder requestWithCsrf = post("/") + .param(token.getParameterName(), token.getToken()) + .session((MockHttpSession)request.getSession()); + this.mockMvc.perform(requestWithCsrf) + .andExpect(status().isOk()); + // @formatter:on + } + public static ResultMatcher csrfAsParam() { return new CsrfParamResultMatcher(); }