SEC-3097: Change CsrfRequestPostProcessor to use TestCsrfTokenRepository

This ensures that when using a wrapped HttpServletRequest (i.e. Spring
Session) that the CSRF token test support still works.
This commit is contained in:
Rob Winch 2015-09-01 23:18:51 -05:00
parent ea94706319
commit 81e2778106
5 changed files with 114 additions and 68 deletions

View File

@ -316,6 +316,10 @@ public final class SecurityMockMvcRequestPostProcessors {
public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) {
CsrfTokenRepository repository = WebTestUtils.getCsrfTokenRepository(request);
if(!(repository instanceof TestCsrfTokenRepository)) {
repository = new TestCsrfTokenRepository(repository);
WebTestUtils.setCsrfTokenRepository(request, repository);
}
CsrfToken token = repository.generateToken(request);
repository.saveToken(token, request, new MockHttpServletResponse());
String tokenValue = useInvalidToken ? "invalid" + token.getToken() : token
@ -352,6 +356,36 @@ public final class SecurityMockMvcRequestPostProcessors {
private CsrfRequestPostProcessor() {
}
/**
* Used to wrap the CsrfTokenRepository to provide support for testing
* when the request is wrapped (i.e. Spring Session is in use).
*/
static class TestCsrfTokenRepository implements
CsrfTokenRepository {
final static String ATTR_NAME = TestCsrfTokenRepository.class
.getName().concat(".TOKEN");
private final CsrfTokenRepository delegate;
private TestCsrfTokenRepository(CsrfTokenRepository delegate) {
this.delegate = delegate;
}
public CsrfToken generateToken(HttpServletRequest request) {
return delegate.generateToken(request);
}
public void saveToken(CsrfToken token, HttpServletRequest request, HttpServletResponse response) {
request.setAttribute(ATTR_NAME, token);
}
public CsrfToken loadToken(HttpServletRequest request) {
return (CsrfToken) request.getAttribute(ATTR_NAME);
}
}
}
public static class DigestRequestPostProcessor implements RequestPostProcessor {

View File

@ -97,6 +97,22 @@ public abstract class WebTestUtils {
"tokenRepository");
}
/**
* Sets the {@link CsrfTokenRepository} for the specified
* {@link HttpServletRequest}.
*
* @param request the {@link HttpServletRequest} to obtain the
* {@link CsrfTokenRepository}
* @param repository the {@link CsrfTokenRepository} to set
*/
public static void setCsrfTokenRepository(HttpServletRequest request,
CsrfTokenRepository repository) {
CsrfFilter filter = findFilter(request, CsrfFilter.class);
if (filter != null) {
ReflectionTestUtils.setField(filter, "tokenRepository", repository);
}
}
@SuppressWarnings("unchecked")
private static <T extends Filter> T findFilter(HttpServletRequest request,
Class<T> filterClass) {

View File

@ -15,48 +15,28 @@
*/
package org.springframework.security.test.web.servlet.request;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.verify;
import static org.fest.assertions.Assertions.assertThat;
import static org.powermock.api.mockito.PowerMockito.spy;
import static org.powermock.api.mockito.PowerMockito.when;
import static org.powermock.api.mockito.PowerMockito.doReturn;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockServletContext;
import org.springframework.security.test.web.support.WebTestUtils;
import org.springframework.security.web.csrf.CsrfTokenRepository;
import org.springframework.security.web.csrf.DefaultCsrfToken;
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.CsrfRequestPostProcessor;
import org.springframework.security.web.csrf.CsrfToken;
@RunWith(PowerMockRunner.class)
@PrepareForTest({ WebTestUtils.class, SecurityMockMvcRequestBuildersFormLoginTests.class })
public class SecurityMockMvcRequestBuildersFormLoginTests {
@Mock
private CsrfTokenRepository repository;
private DefaultCsrfToken token;
private MockServletContext servletContext;
@Before
public void setup() throws Exception {
token = new DefaultCsrfToken("header", "param", "token");
servletContext = new MockServletContext();
mockWebTestUtils();
}
@Test
public void defaults() throws Exception {
MockHttpServletRequest request = formLogin().buildRequest(servletContext);
CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME);
assertThat(request.getParameter("username")).isEqualTo("user");
assertThat(request.getParameter("password")).isEqualTo("password");
@ -64,8 +44,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests {
assertThat(request.getParameter(token.getParameterName())).isEqualTo(
token.getToken());
assertThat(request.getRequestURI()).isEqualTo("/login");
verify(repository).saveToken(eq(token), any(HttpServletRequest.class),
any(HttpServletResponse.class));
assertThat(request.getParameter("_csrf")).isNotNull();
}
@Test
@ -73,20 +52,13 @@ public class SecurityMockMvcRequestBuildersFormLoginTests {
MockHttpServletRequest request = formLogin("/login").user("username", "admin")
.password("password", "secret").buildRequest(servletContext);
CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME);
assertThat(request.getParameter("username")).isEqualTo("admin");
assertThat(request.getParameter("password")).isEqualTo("secret");
assertThat(request.getMethod()).isEqualTo("POST");
assertThat(request.getParameter(token.getParameterName())).isEqualTo(
token.getToken());
assertThat(request.getRequestURI()).isEqualTo("/login");
verify(repository).saveToken(eq(token), any(HttpServletRequest.class),
any(HttpServletResponse.class));
}
private void mockWebTestUtils() throws Exception {
spy(WebTestUtils.class);
doReturn(repository).when(WebTestUtils.class, "getCsrfTokenRepository",
any(HttpServletRequest.class));
when(repository.generateToken(any(HttpServletRequest.class))).thenReturn(token);
}
}

View File

@ -15,75 +15,47 @@
*/
package org.springframework.security.test.web.servlet.request;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.verify;
import static org.fest.assertions.Assertions.assertThat;
import static org.powermock.api.mockito.PowerMockito.spy;
import static org.powermock.api.mockito.PowerMockito.when;
import static org.powermock.api.mockito.PowerMockito.doReturn;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.logout;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockServletContext;
import org.springframework.security.test.web.support.WebTestUtils;
import org.springframework.security.web.csrf.CsrfTokenRepository;
import org.springframework.security.web.csrf.DefaultCsrfToken;
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.CsrfRequestPostProcessor;
import org.springframework.security.web.csrf.CsrfToken;
@RunWith(PowerMockRunner.class)
@PrepareForTest({ WebTestUtils.class, SecurityMockMvcRequestBuildersFormLogoutTests.class })
public class SecurityMockMvcRequestBuildersFormLogoutTests {
@Mock
private CsrfTokenRepository repository;
private DefaultCsrfToken token;
private MockServletContext servletContext;
@Before
public void setup() {
token = new DefaultCsrfToken("header", "param", "token");
servletContext = new MockServletContext();
}
@Test
public void defaults() throws Exception {
mockWebTestUtils();
MockHttpServletRequest request = logout().buildRequest(servletContext);
CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME);
assertThat(request.getMethod()).isEqualTo("POST");
assertThat(request.getParameter(token.getParameterName())).isEqualTo(
token.getToken());
assertThat(request.getRequestURI()).isEqualTo("/logout");
verify(repository).saveToken(eq(token), any(HttpServletRequest.class),
any(HttpServletResponse.class));
}
@Test
public void custom() throws Exception {
mockWebTestUtils();
MockHttpServletRequest request = logout("/admin/logout").buildRequest(
servletContext);
CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME);
assertThat(request.getMethod()).isEqualTo("POST");
assertThat(request.getParameter(token.getParameterName())).isEqualTo(
token.getToken());
assertThat(request.getRequestURI()).isEqualTo("/admin/logout");
verify(repository).saveToken(eq(token), any(HttpServletRequest.class),
any(HttpServletResponse.class));
}
private void mockWebTestUtils() throws Exception {
spy(WebTestUtils.class);
doReturn(repository).when(WebTestUtils.class, "getCsrfTokenRepository",
any(HttpServletRequest.class));
when(repository.generateToken(any(HttpServletRequest.class))).thenReturn(token);
}
}

View File

@ -21,12 +21,22 @@ import static org.springframework.security.test.web.servlet.setup.SecurityMockMv
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
import java.io.IOException;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpSession;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
@ -40,6 +50,7 @@ import org.springframework.test.web.servlet.setup.MockMvcBuilders;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.filter.OncePerRequestFilter;
@RunWith(SpringJUnit4ClassRunner.class)
@ContextConfiguration
@ -86,6 +97,20 @@ public class SecurityMockMvcRequestPostProcessorsCsrfTests {
.andExpect(csrfAsHeader());
}
// SEC-3097
@Test
public void csrfWithWrappedRequest() throws Exception {
mockMvc = MockMvcBuilders
.webAppContextSetup(wac)
.addFilter(new SessionRepositoryFilter())
.apply(springSecurity())
.build();
mockMvc.perform(post("/").with(csrf()))
.andExpect(status().is2xxSuccessful())
.andExpect(csrfAsParam());
}
public static ResultMatcher csrfAsParam() {
return new CsrfParamResultMatcher();
}
@ -112,6 +137,33 @@ public class SecurityMockMvcRequestPostProcessorsCsrfTests {
}
}
static class SessionRepositoryFilter extends OncePerRequestFilter {
@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException {
filterChain.doFilter(new SessionRequestWrapper(request) , response);
}
static class SessionRequestWrapper extends HttpServletRequestWrapper {
HttpSession session = new MockHttpSession();
public SessionRequestWrapper(HttpServletRequest request) {
super(request);
}
@Override
public HttpSession getSession(boolean create) {
return session;
}
@Override
public HttpSession getSession() {
return session;
}
}
}
@EnableWebSecurity
static class Config extends WebSecurityConfigurerAdapter {
@Override