SEC-3097: Change CsrfRequestPostProcessor to use TestCsrfTokenRepository

This ensures that when using a wrapped HttpServletRequest (i.e. Spring
Session) that the CSRF token test support still works.
This commit is contained in:
Rob Winch 2015-09-01 23:18:51 -05:00
parent ea94706319
commit 81e2778106
5 changed files with 114 additions and 68 deletions

View File

@ -316,6 +316,10 @@ public final class SecurityMockMvcRequestPostProcessors {
public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) {
CsrfTokenRepository repository = WebTestUtils.getCsrfTokenRepository(request); CsrfTokenRepository repository = WebTestUtils.getCsrfTokenRepository(request);
if(!(repository instanceof TestCsrfTokenRepository)) {
repository = new TestCsrfTokenRepository(repository);
WebTestUtils.setCsrfTokenRepository(request, repository);
}
CsrfToken token = repository.generateToken(request); CsrfToken token = repository.generateToken(request);
repository.saveToken(token, request, new MockHttpServletResponse()); repository.saveToken(token, request, new MockHttpServletResponse());
String tokenValue = useInvalidToken ? "invalid" + token.getToken() : token String tokenValue = useInvalidToken ? "invalid" + token.getToken() : token
@ -352,6 +356,36 @@ public final class SecurityMockMvcRequestPostProcessors {
private CsrfRequestPostProcessor() { private CsrfRequestPostProcessor() {
} }
/**
* Used to wrap the CsrfTokenRepository to provide support for testing
* when the request is wrapped (i.e. Spring Session is in use).
*/
static class TestCsrfTokenRepository implements
CsrfTokenRepository {
final static String ATTR_NAME = TestCsrfTokenRepository.class
.getName().concat(".TOKEN");
private final CsrfTokenRepository delegate;
private TestCsrfTokenRepository(CsrfTokenRepository delegate) {
this.delegate = delegate;
}
public CsrfToken generateToken(HttpServletRequest request) {
return delegate.generateToken(request);
}
public void saveToken(CsrfToken token, HttpServletRequest request, HttpServletResponse response) {
request.setAttribute(ATTR_NAME, token);
}
public CsrfToken loadToken(HttpServletRequest request) {
return (CsrfToken) request.getAttribute(ATTR_NAME);
}
}
} }
public static class DigestRequestPostProcessor implements RequestPostProcessor { public static class DigestRequestPostProcessor implements RequestPostProcessor {

View File

@ -97,6 +97,22 @@ public abstract class WebTestUtils {
"tokenRepository"); "tokenRepository");
} }
/**
* Sets the {@link CsrfTokenRepository} for the specified
* {@link HttpServletRequest}.
*
* @param request the {@link HttpServletRequest} to obtain the
* {@link CsrfTokenRepository}
* @param repository the {@link CsrfTokenRepository} to set
*/
public static void setCsrfTokenRepository(HttpServletRequest request,
CsrfTokenRepository repository) {
CsrfFilter filter = findFilter(request, CsrfFilter.class);
if (filter != null) {
ReflectionTestUtils.setField(filter, "tokenRepository", repository);
}
}
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private static <T extends Filter> T findFilter(HttpServletRequest request, private static <T extends Filter> T findFilter(HttpServletRequest request,
Class<T> filterClass) { Class<T> filterClass) {

View File

@ -15,48 +15,28 @@
*/ */
package org.springframework.security.test.web.servlet.request; package org.springframework.security.test.web.servlet.request;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.verify;
import static org.fest.assertions.Assertions.assertThat; import static org.fest.assertions.Assertions.assertThat;
import static org.powermock.api.mockito.PowerMockito.spy;
import static org.powermock.api.mockito.PowerMockito.when;
import static org.powermock.api.mockito.PowerMockito.doReturn;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockServletContext; import org.springframework.mock.web.MockServletContext;
import org.springframework.security.test.web.support.WebTestUtils; import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.CsrfRequestPostProcessor;
import org.springframework.security.web.csrf.CsrfTokenRepository; import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.DefaultCsrfToken;
@RunWith(PowerMockRunner.class)
@PrepareForTest({ WebTestUtils.class, SecurityMockMvcRequestBuildersFormLoginTests.class })
public class SecurityMockMvcRequestBuildersFormLoginTests { public class SecurityMockMvcRequestBuildersFormLoginTests {
@Mock
private CsrfTokenRepository repository;
private DefaultCsrfToken token;
private MockServletContext servletContext; private MockServletContext servletContext;
@Before @Before
public void setup() throws Exception { public void setup() throws Exception {
token = new DefaultCsrfToken("header", "param", "token");
servletContext = new MockServletContext(); servletContext = new MockServletContext();
mockWebTestUtils();
} }
@Test @Test
public void defaults() throws Exception { public void defaults() throws Exception {
MockHttpServletRequest request = formLogin().buildRequest(servletContext); MockHttpServletRequest request = formLogin().buildRequest(servletContext);
CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME);
assertThat(request.getParameter("username")).isEqualTo("user"); assertThat(request.getParameter("username")).isEqualTo("user");
assertThat(request.getParameter("password")).isEqualTo("password"); assertThat(request.getParameter("password")).isEqualTo("password");
@ -64,8 +44,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests {
assertThat(request.getParameter(token.getParameterName())).isEqualTo( assertThat(request.getParameter(token.getParameterName())).isEqualTo(
token.getToken()); token.getToken());
assertThat(request.getRequestURI()).isEqualTo("/login"); assertThat(request.getRequestURI()).isEqualTo("/login");
verify(repository).saveToken(eq(token), any(HttpServletRequest.class), assertThat(request.getParameter("_csrf")).isNotNull();
any(HttpServletResponse.class));
} }
@Test @Test
@ -73,20 +52,13 @@ public class SecurityMockMvcRequestBuildersFormLoginTests {
MockHttpServletRequest request = formLogin("/login").user("username", "admin") MockHttpServletRequest request = formLogin("/login").user("username", "admin")
.password("password", "secret").buildRequest(servletContext); .password("password", "secret").buildRequest(servletContext);
CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME);
assertThat(request.getParameter("username")).isEqualTo("admin"); assertThat(request.getParameter("username")).isEqualTo("admin");
assertThat(request.getParameter("password")).isEqualTo("secret"); assertThat(request.getParameter("password")).isEqualTo("secret");
assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getMethod()).isEqualTo("POST");
assertThat(request.getParameter(token.getParameterName())).isEqualTo( assertThat(request.getParameter(token.getParameterName())).isEqualTo(
token.getToken()); token.getToken());
assertThat(request.getRequestURI()).isEqualTo("/login"); assertThat(request.getRequestURI()).isEqualTo("/login");
verify(repository).saveToken(eq(token), any(HttpServletRequest.class),
any(HttpServletResponse.class));
}
private void mockWebTestUtils() throws Exception {
spy(WebTestUtils.class);
doReturn(repository).when(WebTestUtils.class, "getCsrfTokenRepository",
any(HttpServletRequest.class));
when(repository.generateToken(any(HttpServletRequest.class))).thenReturn(token);
} }
} }

View File

@ -15,75 +15,47 @@
*/ */
package org.springframework.security.test.web.servlet.request; package org.springframework.security.test.web.servlet.request;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.verify;
import static org.fest.assertions.Assertions.assertThat; import static org.fest.assertions.Assertions.assertThat;
import static org.powermock.api.mockito.PowerMockito.spy;
import static org.powermock.api.mockito.PowerMockito.when;
import static org.powermock.api.mockito.PowerMockito.doReturn;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.logout; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.logout;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockServletContext; import org.springframework.mock.web.MockServletContext;
import org.springframework.security.test.web.support.WebTestUtils; import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.CsrfRequestPostProcessor;
import org.springframework.security.web.csrf.CsrfTokenRepository; import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.DefaultCsrfToken;
@RunWith(PowerMockRunner.class)
@PrepareForTest({ WebTestUtils.class, SecurityMockMvcRequestBuildersFormLogoutTests.class })
public class SecurityMockMvcRequestBuildersFormLogoutTests { public class SecurityMockMvcRequestBuildersFormLogoutTests {
@Mock
private CsrfTokenRepository repository;
private DefaultCsrfToken token;
private MockServletContext servletContext; private MockServletContext servletContext;
@Before @Before
public void setup() { public void setup() {
token = new DefaultCsrfToken("header", "param", "token");
servletContext = new MockServletContext(); servletContext = new MockServletContext();
} }
@Test @Test
public void defaults() throws Exception { public void defaults() throws Exception {
mockWebTestUtils();
MockHttpServletRequest request = logout().buildRequest(servletContext); MockHttpServletRequest request = logout().buildRequest(servletContext);
CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME);
assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getMethod()).isEqualTo("POST");
assertThat(request.getParameter(token.getParameterName())).isEqualTo( assertThat(request.getParameter(token.getParameterName())).isEqualTo(
token.getToken()); token.getToken());
assertThat(request.getRequestURI()).isEqualTo("/logout"); assertThat(request.getRequestURI()).isEqualTo("/logout");
verify(repository).saveToken(eq(token), any(HttpServletRequest.class),
any(HttpServletResponse.class));
} }
@Test @Test
public void custom() throws Exception { public void custom() throws Exception {
mockWebTestUtils();
MockHttpServletRequest request = logout("/admin/logout").buildRequest( MockHttpServletRequest request = logout("/admin/logout").buildRequest(
servletContext); servletContext);
CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME);
assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getMethod()).isEqualTo("POST");
assertThat(request.getParameter(token.getParameterName())).isEqualTo( assertThat(request.getParameter(token.getParameterName())).isEqualTo(
token.getToken()); token.getToken());
assertThat(request.getRequestURI()).isEqualTo("/admin/logout"); assertThat(request.getRequestURI()).isEqualTo("/admin/logout");
verify(repository).saveToken(eq(token), any(HttpServletRequest.class),
any(HttpServletResponse.class));
} }
private void mockWebTestUtils() throws Exception {
spy(WebTestUtils.class);
doReturn(repository).when(WebTestUtils.class, "getCsrfTokenRepository",
any(HttpServletRequest.class));
when(repository.generateToken(any(HttpServletRequest.class))).thenReturn(token);
}
} }

View File

@ -21,12 +21,22 @@ import static org.springframework.security.test.web.servlet.setup.SecurityMockMv
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
import java.io.IOException;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpSession;
import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
@ -40,6 +50,7 @@ import org.springframework.test.web.servlet.setup.MockMvcBuilders;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.context.WebApplicationContext; import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.filter.OncePerRequestFilter;
@RunWith(SpringJUnit4ClassRunner.class) @RunWith(SpringJUnit4ClassRunner.class)
@ContextConfiguration @ContextConfiguration
@ -86,6 +97,20 @@ public class SecurityMockMvcRequestPostProcessorsCsrfTests {
.andExpect(csrfAsHeader()); .andExpect(csrfAsHeader());
} }
// SEC-3097
@Test
public void csrfWithWrappedRequest() throws Exception {
mockMvc = MockMvcBuilders
.webAppContextSetup(wac)
.addFilter(new SessionRepositoryFilter())
.apply(springSecurity())
.build();
mockMvc.perform(post("/").with(csrf()))
.andExpect(status().is2xxSuccessful())
.andExpect(csrfAsParam());
}
public static ResultMatcher csrfAsParam() { public static ResultMatcher csrfAsParam() {
return new CsrfParamResultMatcher(); return new CsrfParamResultMatcher();
} }
@ -112,6 +137,33 @@ public class SecurityMockMvcRequestPostProcessorsCsrfTests {
} }
} }
static class SessionRepositoryFilter extends OncePerRequestFilter {
@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException {
filterChain.doFilter(new SessionRequestWrapper(request) , response);
}
static class SessionRequestWrapper extends HttpServletRequestWrapper {
HttpSession session = new MockHttpSession();
public SessionRequestWrapper(HttpServletRequest request) {
super(request);
}
@Override
public HttpSession getSession(boolean create) {
return session;
}
@Override
public HttpSession getSession() {
return session;
}
}
}
@EnableWebSecurity @EnableWebSecurity
static class Config extends WebSecurityConfigurerAdapter { static class Config extends WebSecurityConfigurerAdapter {
@Override @Override