diff --git a/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java b/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java index 2249fa49df..8c7cad5ff5 100644 --- a/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java +++ b/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java @@ -316,6 +316,10 @@ public final class SecurityMockMvcRequestPostProcessors { public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { CsrfTokenRepository repository = WebTestUtils.getCsrfTokenRepository(request); + if(!(repository instanceof TestCsrfTokenRepository)) { + repository = new TestCsrfTokenRepository(repository); + WebTestUtils.setCsrfTokenRepository(request, repository); + } CsrfToken token = repository.generateToken(request); repository.saveToken(token, request, new MockHttpServletResponse()); String tokenValue = useInvalidToken ? "invalid" + token.getToken() : token @@ -352,6 +356,36 @@ public final class SecurityMockMvcRequestPostProcessors { private CsrfRequestPostProcessor() { } + + + + /** + * Used to wrap the CsrfTokenRepository to provide support for testing + * when the request is wrapped (i.e. Spring Session is in use). + */ + static class TestCsrfTokenRepository implements + CsrfTokenRepository { + final static String ATTR_NAME = TestCsrfTokenRepository.class + .getName().concat(".TOKEN"); + + private final CsrfTokenRepository delegate; + + private TestCsrfTokenRepository(CsrfTokenRepository delegate) { + this.delegate = delegate; + } + + public CsrfToken generateToken(HttpServletRequest request) { + return delegate.generateToken(request); + } + + public void saveToken(CsrfToken token, HttpServletRequest request, HttpServletResponse response) { + request.setAttribute(ATTR_NAME, token); + } + + public CsrfToken loadToken(HttpServletRequest request) { + return (CsrfToken) request.getAttribute(ATTR_NAME); + } + } } public static class DigestRequestPostProcessor implements RequestPostProcessor { diff --git a/test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java b/test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java index 360d51d005..d9be1dcf58 100644 --- a/test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java +++ b/test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java @@ -97,6 +97,22 @@ public abstract class WebTestUtils { "tokenRepository"); } + /** + * Sets the {@link CsrfTokenRepository} for the specified + * {@link HttpServletRequest}. + * + * @param request the {@link HttpServletRequest} to obtain the + * {@link CsrfTokenRepository} + * @param repository the {@link CsrfTokenRepository} to set + */ + public static void setCsrfTokenRepository(HttpServletRequest request, + CsrfTokenRepository repository) { + CsrfFilter filter = findFilter(request, CsrfFilter.class); + if (filter != null) { + ReflectionTestUtils.setField(filter, "tokenRepository", repository); + } + } + @SuppressWarnings("unchecked") private static T findFilter(HttpServletRequest request, Class filterClass) { diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java index 21b554791d..deedb9d06f 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java @@ -15,48 +15,28 @@ */ package org.springframework.security.test.web.servlet.request; -import static org.mockito.Matchers.any; -import static org.mockito.Matchers.eq; -import static org.mockito.Mockito.verify; import static org.fest.assertions.Assertions.assertThat; -import static org.powermock.api.mockito.PowerMockito.spy; -import static org.powermock.api.mockito.PowerMockito.when; -import static org.powermock.api.mockito.PowerMockito.doReturn; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - import org.junit.Before; import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.Mock; -import org.powermock.core.classloader.annotations.PrepareForTest; -import org.powermock.modules.junit4.PowerMockRunner; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockServletContext; -import org.springframework.security.test.web.support.WebTestUtils; -import org.springframework.security.web.csrf.CsrfTokenRepository; -import org.springframework.security.web.csrf.DefaultCsrfToken; +import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.CsrfRequestPostProcessor; +import org.springframework.security.web.csrf.CsrfToken; -@RunWith(PowerMockRunner.class) -@PrepareForTest({ WebTestUtils.class, SecurityMockMvcRequestBuildersFormLoginTests.class }) public class SecurityMockMvcRequestBuildersFormLoginTests { - @Mock - private CsrfTokenRepository repository; - private DefaultCsrfToken token; private MockServletContext servletContext; @Before public void setup() throws Exception { - token = new DefaultCsrfToken("header", "param", "token"); servletContext = new MockServletContext(); - mockWebTestUtils(); } @Test public void defaults() throws Exception { MockHttpServletRequest request = formLogin().buildRequest(servletContext); + CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME); assertThat(request.getParameter("username")).isEqualTo("user"); assertThat(request.getParameter("password")).isEqualTo("password"); @@ -64,8 +44,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests { assertThat(request.getParameter(token.getParameterName())).isEqualTo( token.getToken()); assertThat(request.getRequestURI()).isEqualTo("/login"); - verify(repository).saveToken(eq(token), any(HttpServletRequest.class), - any(HttpServletResponse.class)); + assertThat(request.getParameter("_csrf")).isNotNull(); } @Test @@ -73,20 +52,13 @@ public class SecurityMockMvcRequestBuildersFormLoginTests { MockHttpServletRequest request = formLogin("/login").user("username", "admin") .password("password", "secret").buildRequest(servletContext); + CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME); + assertThat(request.getParameter("username")).isEqualTo("admin"); assertThat(request.getParameter("password")).isEqualTo("secret"); assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getParameter(token.getParameterName())).isEqualTo( token.getToken()); assertThat(request.getRequestURI()).isEqualTo("/login"); - verify(repository).saveToken(eq(token), any(HttpServletRequest.class), - any(HttpServletResponse.class)); - } - - private void mockWebTestUtils() throws Exception { - spy(WebTestUtils.class); - doReturn(repository).when(WebTestUtils.class, "getCsrfTokenRepository", - any(HttpServletRequest.class)); - when(repository.generateToken(any(HttpServletRequest.class))).thenReturn(token); } } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java index 7fc1fe7d49..5652fd505e 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java @@ -15,75 +15,47 @@ */ package org.springframework.security.test.web.servlet.request; -import static org.mockito.Matchers.any; -import static org.mockito.Matchers.eq; -import static org.mockito.Mockito.verify; import static org.fest.assertions.Assertions.assertThat; -import static org.powermock.api.mockito.PowerMockito.spy; -import static org.powermock.api.mockito.PowerMockito.when; -import static org.powermock.api.mockito.PowerMockito.doReturn; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.logout; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - import org.junit.Before; import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.Mock; -import org.powermock.core.classloader.annotations.PrepareForTest; -import org.powermock.modules.junit4.PowerMockRunner; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockServletContext; -import org.springframework.security.test.web.support.WebTestUtils; -import org.springframework.security.web.csrf.CsrfTokenRepository; -import org.springframework.security.web.csrf.DefaultCsrfToken; +import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.CsrfRequestPostProcessor; +import org.springframework.security.web.csrf.CsrfToken; -@RunWith(PowerMockRunner.class) -@PrepareForTest({ WebTestUtils.class, SecurityMockMvcRequestBuildersFormLogoutTests.class }) public class SecurityMockMvcRequestBuildersFormLogoutTests { - @Mock - private CsrfTokenRepository repository; - private DefaultCsrfToken token; private MockServletContext servletContext; @Before public void setup() { - token = new DefaultCsrfToken("header", "param", "token"); servletContext = new MockServletContext(); } @Test public void defaults() throws Exception { - mockWebTestUtils(); MockHttpServletRequest request = logout().buildRequest(servletContext); + CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME); + assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getParameter(token.getParameterName())).isEqualTo( token.getToken()); assertThat(request.getRequestURI()).isEqualTo("/logout"); - verify(repository).saveToken(eq(token), any(HttpServletRequest.class), - any(HttpServletResponse.class)); } @Test public void custom() throws Exception { - mockWebTestUtils(); MockHttpServletRequest request = logout("/admin/logout").buildRequest( servletContext); + CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME); + assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getParameter(token.getParameterName())).isEqualTo( token.getToken()); assertThat(request.getRequestURI()).isEqualTo("/admin/logout"); - verify(repository).saveToken(eq(token), any(HttpServletRequest.class), - any(HttpServletResponse.class)); } - private void mockWebTestUtils() throws Exception { - spy(WebTestUtils.class); - doReturn(repository).when(WebTestUtils.class, "getCsrfTokenRepository", - any(HttpServletRequest.class)); - when(repository.generateToken(any(HttpServletRequest.class))).thenReturn(token); - } } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfTests.java index 3c47fdf107..f3aa392634 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfTests.java @@ -21,12 +21,22 @@ import static org.springframework.security.test.web.servlet.setup.SecurityMockMv import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; +import java.io.IOException; + +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletRequestWrapper; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpSession; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpSession; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; @@ -40,6 +50,7 @@ import org.springframework.test.web.servlet.setup.MockMvcBuilders; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.filter.OncePerRequestFilter; @RunWith(SpringJUnit4ClassRunner.class) @ContextConfiguration @@ -86,6 +97,20 @@ public class SecurityMockMvcRequestPostProcessorsCsrfTests { .andExpect(csrfAsHeader()); } + // SEC-3097 + @Test + public void csrfWithWrappedRequest() throws Exception { + mockMvc = MockMvcBuilders + .webAppContextSetup(wac) + .addFilter(new SessionRepositoryFilter()) + .apply(springSecurity()) + .build(); + + mockMvc.perform(post("/").with(csrf())) + .andExpect(status().is2xxSuccessful()) + .andExpect(csrfAsParam()); + } + public static ResultMatcher csrfAsParam() { return new CsrfParamResultMatcher(); } @@ -112,6 +137,33 @@ public class SecurityMockMvcRequestPostProcessorsCsrfTests { } } + static class SessionRepositoryFilter extends OncePerRequestFilter { + + @Override + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) + throws ServletException, IOException { + filterChain.doFilter(new SessionRequestWrapper(request) , response); + } + + static class SessionRequestWrapper extends HttpServletRequestWrapper { + HttpSession session = new MockHttpSession(); + + public SessionRequestWrapper(HttpServletRequest request) { + super(request); + } + + @Override + public HttpSession getSession(boolean create) { + return session; + } + + @Override + public HttpSession getSession() { + return session; + } + } + } + @EnableWebSecurity static class Config extends WebSecurityConfigurerAdapter { @Override