diff --git a/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java b/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java index 86bcef4779..06e138f4db 100644 --- a/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java +++ b/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java @@ -318,13 +318,13 @@ public final class SecurityMockMvcRequestPostProcessors { */ @Override public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { - CsrfTokenRepository repository = WebTestUtils.getCsrfTokenRepository(request); if (!(repository instanceof TestCsrfTokenRepository)) { repository = new TestCsrfTokenRepository( new HttpSessionCsrfTokenRepository()); WebTestUtils.setCsrfTokenRepository(request, repository); } + TestCsrfTokenRepository.enable(request); CsrfToken token = repository.generateToken(request); repository.saveToken(token, request, new MockHttpServletResponse()); String tokenValue = this.useInvalidToken ? "invalid" + token.getToken() @@ -367,9 +367,12 @@ public final class SecurityMockMvcRequestPostProcessors { * request is wrapped (i.e. Spring Session is in use). */ static class TestCsrfTokenRepository implements CsrfTokenRepository { - final static String ATTR_NAME = TestCsrfTokenRepository.class.getName() + final static String TOKEN_ATTR_NAME = TestCsrfTokenRepository.class.getName() .concat(".TOKEN"); + final static String ENABLED_ATTR_NAME = TestCsrfTokenRepository.class + .getName().concat(".ENABLED"); + private final CsrfTokenRepository delegate; private TestCsrfTokenRepository(CsrfTokenRepository delegate) { @@ -384,12 +387,30 @@ public final class SecurityMockMvcRequestPostProcessors { @Override public void saveToken(CsrfToken token, HttpServletRequest request, HttpServletResponse response) { - request.setAttribute(ATTR_NAME, token); + if (isEnabled(request)) { + request.setAttribute(TOKEN_ATTR_NAME, token); + } + else { + this.delegate.saveToken(token, request, response); + } } @Override public CsrfToken loadToken(HttpServletRequest request) { - return (CsrfToken) request.getAttribute(ATTR_NAME); + if (isEnabled(request)) { + return (CsrfToken) request.getAttribute(TOKEN_ATTR_NAME); + } + else { + return this.delegate.loadToken(request); + } + } + + public static void enable(HttpServletRequest request) { + request.setAttribute(ENABLED_ATTR_NAME, Boolean.TRUE); + } + + public boolean isEnabled(HttpServletRequest request) { + return Boolean.TRUE.equals(request.getAttribute(ENABLED_ATTR_NAME)); } } } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java index bddfa3d218..24e3a58316 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java @@ -39,7 +39,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests { public void defaults() throws Exception { MockHttpServletRequest request = formLogin().buildRequest(this.servletContext); CsrfToken token = (CsrfToken) request - .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME); + .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); assertThat(request.getParameter("username")).isEqualTo("user"); assertThat(request.getParameter("password")).isEqualTo("password"); @@ -56,7 +56,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests { .password("password", "secret").buildRequest(this.servletContext); CsrfToken token = (CsrfToken) request - .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME); + .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); assertThat(request.getParameter("username")).isEqualTo("admin"); assertThat(request.getParameter("password")).isEqualTo("secret"); diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java index 16b7ebec06..4d5314bb14 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java @@ -37,7 +37,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests { public void defaults() throws Exception { MockHttpServletRequest request = logout().buildRequest(servletContext); - CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME); + CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getParameter(token.getParameterName())).isEqualTo( @@ -50,7 +50,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests { MockHttpServletRequest request = logout("/admin/logout").buildRequest( servletContext); - CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME); + CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getParameter(token.getParameterName())).isEqualTo( diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfTests.java index a405c6ddbb..f79d73b90d 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfTests.java @@ -31,18 +31,22 @@ import org.junit.runner.RunWith; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpSession; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessorsCsrfTests.Config.TheController; import org.springframework.security.web.FilterChainProxy; +import org.springframework.security.web.csrf.CsrfToken; +import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; import org.springframework.test.context.web.WebAppConfiguration; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; import org.springframework.test.web.servlet.ResultMatcher; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import org.springframework.test.web.servlet.setup.MockMvcBuilders; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; @@ -143,6 +147,25 @@ public class SecurityMockMvcRequestPostProcessorsCsrfTests { // @formatter:on } + // gh-4016 + @Test + public void csrfWhenUsedThenDoesNotImpactOriginalRepository() throws Exception { + // @formatter:off + this.mockMvc.perform(post("/").with(csrf())); + + MockHttpServletRequest request = new MockHttpServletRequest(); + HttpSessionCsrfTokenRepository repo = new HttpSessionCsrfTokenRepository(); + CsrfToken token = repo.generateToken(request); + repo.saveToken(token, request, new MockHttpServletResponse()); + + MockHttpServletRequestBuilder requestWithCsrf = post("/") + .param(token.getParameterName(), token.getToken()) + .session((MockHttpSession)request.getSession()); + this.mockMvc.perform(requestWithCsrf) + .andExpect(status().isOk()); + // @formatter:on + } + public static ResultMatcher csrfAsParam() { return new CsrfParamResultMatcher(); }