diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfTests.java index 50283689cf..3c47fdf107 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfTests.java @@ -17,78 +17,118 @@ package org.springframework.security.test.web.servlet.request; import static org.fest.assertions.Assertions.assertThat; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; -import static org.powermock.api.mockito.PowerMockito.*; +import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; -import org.mockito.Mock; -import org.powermock.core.classloader.annotations.PrepareOnlyThisForTest; -import org.powermock.modules.junit4.PowerMockRunner; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; import org.springframework.mock.web.MockHttpServletRequest; -import org.springframework.security.test.web.support.WebTestUtils; -import org.springframework.security.web.csrf.CsrfTokenRepository; -import org.springframework.security.web.csrf.DefaultCsrfToken; +import org.springframework.security.config.annotation.web.builders.HttpSecurity; +import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; +import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.test.context.web.WebAppConfiguration; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.MvcResult; +import org.springframework.test.web.servlet.ResultMatcher; +import org.springframework.test.web.servlet.setup.MockMvcBuilders; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.context.WebApplicationContext; -@RunWith(PowerMockRunner.class) -@PrepareOnlyThisForTest(WebTestUtils.class) +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration +@WebAppConfiguration public class SecurityMockMvcRequestPostProcessorsCsrfTests { - @Mock - private CsrfTokenRepository repository; - private DefaultCsrfToken token; + @Autowired + WebApplicationContext wac; - private MockHttpServletRequest request; + MockMvc mockMvc; @Before public void setup() { - token = new DefaultCsrfToken("header", "param", "token"); - request = new MockHttpServletRequest(); - mockWebTestUtils(); + mockMvc = MockMvcBuilders + .webAppContextSetup(wac) + .apply(springSecurity()) + .build(); } @Test - public void csrfWithParam() { - MockHttpServletRequest postProcessedRequest = csrf().postProcessRequest(request); - - assertThat(postProcessedRequest.getParameter(token.getParameterName())) - .isEqualTo(token.getToken()); - assertThat(postProcessedRequest.getHeader(token.getHeaderName())).isNull(); + public void csrfWithParam() throws Exception { + mockMvc.perform(post("/").with(csrf())) + .andExpect(status().is2xxSuccessful()) + .andExpect(csrfAsParam()); } @Test - public void csrfWithHeader() { - MockHttpServletRequest postProcessedRequest = csrf().asHeader() - .postProcessRequest(request); - - assertThat(postProcessedRequest.getParameter(token.getParameterName())).isNull(); - assertThat(postProcessedRequest.getHeader(token.getHeaderName())).isEqualTo( - token.getToken()); + public void csrfWithHeader() throws Exception { + mockMvc.perform(post("/").with(csrf().asHeader())) + .andExpect(status().is2xxSuccessful()) + .andExpect(csrfAsHeader()); } @Test - public void csrfWithInvalidParam() { - MockHttpServletRequest postProcessedRequest = csrf().useInvalidToken() - .postProcessRequest(request); - - assertThat(postProcessedRequest.getParameter(token.getParameterName())) - .isNotEmpty().isNotEqualTo(token.getToken()); - assertThat(postProcessedRequest.getHeader(token.getHeaderName())).isNull(); + public void csrfWithInvalidParam() throws Exception { + mockMvc.perform(post("/").with(csrf().useInvalidToken())) + .andExpect(status().isForbidden()) + .andExpect(csrfAsParam()); } @Test - public void csrfWithInvalidHeader() { - MockHttpServletRequest postProcessedRequest = csrf().asHeader().useInvalidToken() - .postProcessRequest(request); - - assertThat(postProcessedRequest.getParameter(token.getParameterName())).isNull(); - assertThat(postProcessedRequest.getHeader(token.getHeaderName())).isNotEmpty() - .isNotEqualTo(token.getToken()); + public void csrfWithInvalidHeader() throws Exception { + mockMvc.perform(post("/").with(csrf().asHeader().useInvalidToken())) + .andExpect(status().isForbidden()) + .andExpect(csrfAsHeader()); } - private void mockWebTestUtils() { - spy(WebTestUtils.class); - when(WebTestUtils.getCsrfTokenRepository(request)).thenReturn(repository); - when(repository.loadToken(request)).thenReturn(token); - when(repository.generateToken(request)).thenReturn(token); + public static ResultMatcher csrfAsParam() { + return new CsrfParamResultMatcher(); + } + + static class CsrfParamResultMatcher implements ResultMatcher { + + public void match(MvcResult result) throws Exception { + MockHttpServletRequest request = result.getRequest(); + assertThat(request.getParameter("_csrf")).isNotNull(); + assertThat(request.getHeader("X-CSRF-TOKEN")).isNull(); + } + } + + public static ResultMatcher csrfAsHeader() { + return new CsrfHeaderResultMatcher(); + } + + static class CsrfHeaderResultMatcher implements ResultMatcher { + + public void match(MvcResult result) throws Exception { + MockHttpServletRequest request = result.getRequest(); + assertThat(request.getParameter("_csrf")).isNull(); + assertThat(request.getHeader("X-CSRF-TOKEN")).isNotNull(); + } + } + + @EnableWebSecurity + static class Config extends WebSecurityConfigurerAdapter { + @Override + protected void configure(HttpSecurity http) throws Exception { + } + + @Bean + public TheController controller() { + return new TheController(); + } + + @RestController + static class TheController { + @RequestMapping("/") + String index() { + return "Hi"; + } + } } }