SEC-3097: Change CsrfRequestPostProcessor to use TestCsrfTokenRepository
This ensures that when using a wrapped HttpServletRequest (i.e. Spring Session) that the CSRF token test support still works.
This commit is contained in:
parent
ea94706319
commit
81e2778106
|
@ -316,6 +316,10 @@ public final class SecurityMockMvcRequestPostProcessors {
|
||||||
public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) {
|
public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) {
|
||||||
|
|
||||||
CsrfTokenRepository repository = WebTestUtils.getCsrfTokenRepository(request);
|
CsrfTokenRepository repository = WebTestUtils.getCsrfTokenRepository(request);
|
||||||
|
if(!(repository instanceof TestCsrfTokenRepository)) {
|
||||||
|
repository = new TestCsrfTokenRepository(repository);
|
||||||
|
WebTestUtils.setCsrfTokenRepository(request, repository);
|
||||||
|
}
|
||||||
CsrfToken token = repository.generateToken(request);
|
CsrfToken token = repository.generateToken(request);
|
||||||
repository.saveToken(token, request, new MockHttpServletResponse());
|
repository.saveToken(token, request, new MockHttpServletResponse());
|
||||||
String tokenValue = useInvalidToken ? "invalid" + token.getToken() : token
|
String tokenValue = useInvalidToken ? "invalid" + token.getToken() : token
|
||||||
|
@ -352,6 +356,36 @@ public final class SecurityMockMvcRequestPostProcessors {
|
||||||
|
|
||||||
private CsrfRequestPostProcessor() {
|
private CsrfRequestPostProcessor() {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Used to wrap the CsrfTokenRepository to provide support for testing
|
||||||
|
* when the request is wrapped (i.e. Spring Session is in use).
|
||||||
|
*/
|
||||||
|
static class TestCsrfTokenRepository implements
|
||||||
|
CsrfTokenRepository {
|
||||||
|
final static String ATTR_NAME = TestCsrfTokenRepository.class
|
||||||
|
.getName().concat(".TOKEN");
|
||||||
|
|
||||||
|
private final CsrfTokenRepository delegate;
|
||||||
|
|
||||||
|
private TestCsrfTokenRepository(CsrfTokenRepository delegate) {
|
||||||
|
this.delegate = delegate;
|
||||||
|
}
|
||||||
|
|
||||||
|
public CsrfToken generateToken(HttpServletRequest request) {
|
||||||
|
return delegate.generateToken(request);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void saveToken(CsrfToken token, HttpServletRequest request, HttpServletResponse response) {
|
||||||
|
request.setAttribute(ATTR_NAME, token);
|
||||||
|
}
|
||||||
|
|
||||||
|
public CsrfToken loadToken(HttpServletRequest request) {
|
||||||
|
return (CsrfToken) request.getAttribute(ATTR_NAME);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class DigestRequestPostProcessor implements RequestPostProcessor {
|
public static class DigestRequestPostProcessor implements RequestPostProcessor {
|
||||||
|
|
|
@ -97,6 +97,22 @@ public abstract class WebTestUtils {
|
||||||
"tokenRepository");
|
"tokenRepository");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the {@link CsrfTokenRepository} for the specified
|
||||||
|
* {@link HttpServletRequest}.
|
||||||
|
*
|
||||||
|
* @param request the {@link HttpServletRequest} to obtain the
|
||||||
|
* {@link CsrfTokenRepository}
|
||||||
|
* @param repository the {@link CsrfTokenRepository} to set
|
||||||
|
*/
|
||||||
|
public static void setCsrfTokenRepository(HttpServletRequest request,
|
||||||
|
CsrfTokenRepository repository) {
|
||||||
|
CsrfFilter filter = findFilter(request, CsrfFilter.class);
|
||||||
|
if (filter != null) {
|
||||||
|
ReflectionTestUtils.setField(filter, "tokenRepository", repository);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
private static <T extends Filter> T findFilter(HttpServletRequest request,
|
private static <T extends Filter> T findFilter(HttpServletRequest request,
|
||||||
Class<T> filterClass) {
|
Class<T> filterClass) {
|
||||||
|
|
|
@ -15,48 +15,28 @@
|
||||||
*/
|
*/
|
||||||
package org.springframework.security.test.web.servlet.request;
|
package org.springframework.security.test.web.servlet.request;
|
||||||
|
|
||||||
import static org.mockito.Matchers.any;
|
|
||||||
import static org.mockito.Matchers.eq;
|
|
||||||
import static org.mockito.Mockito.verify;
|
|
||||||
import static org.fest.assertions.Assertions.assertThat;
|
import static org.fest.assertions.Assertions.assertThat;
|
||||||
import static org.powermock.api.mockito.PowerMockito.spy;
|
|
||||||
import static org.powermock.api.mockito.PowerMockito.when;
|
|
||||||
import static org.powermock.api.mockito.PowerMockito.doReturn;
|
|
||||||
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin;
|
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin;
|
||||||
|
|
||||||
import javax.servlet.http.HttpServletRequest;
|
|
||||||
import javax.servlet.http.HttpServletResponse;
|
|
||||||
|
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.junit.runner.RunWith;
|
|
||||||
import org.mockito.Mock;
|
|
||||||
import org.powermock.core.classloader.annotations.PrepareForTest;
|
|
||||||
import org.powermock.modules.junit4.PowerMockRunner;
|
|
||||||
import org.springframework.mock.web.MockHttpServletRequest;
|
import org.springframework.mock.web.MockHttpServletRequest;
|
||||||
import org.springframework.mock.web.MockServletContext;
|
import org.springframework.mock.web.MockServletContext;
|
||||||
import org.springframework.security.test.web.support.WebTestUtils;
|
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.CsrfRequestPostProcessor;
|
||||||
import org.springframework.security.web.csrf.CsrfTokenRepository;
|
import org.springframework.security.web.csrf.CsrfToken;
|
||||||
import org.springframework.security.web.csrf.DefaultCsrfToken;
|
|
||||||
|
|
||||||
@RunWith(PowerMockRunner.class)
|
|
||||||
@PrepareForTest({ WebTestUtils.class, SecurityMockMvcRequestBuildersFormLoginTests.class })
|
|
||||||
public class SecurityMockMvcRequestBuildersFormLoginTests {
|
public class SecurityMockMvcRequestBuildersFormLoginTests {
|
||||||
@Mock
|
|
||||||
private CsrfTokenRepository repository;
|
|
||||||
private DefaultCsrfToken token;
|
|
||||||
private MockServletContext servletContext;
|
private MockServletContext servletContext;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setup() throws Exception {
|
public void setup() throws Exception {
|
||||||
token = new DefaultCsrfToken("header", "param", "token");
|
|
||||||
servletContext = new MockServletContext();
|
servletContext = new MockServletContext();
|
||||||
mockWebTestUtils();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void defaults() throws Exception {
|
public void defaults() throws Exception {
|
||||||
MockHttpServletRequest request = formLogin().buildRequest(servletContext);
|
MockHttpServletRequest request = formLogin().buildRequest(servletContext);
|
||||||
|
CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME);
|
||||||
|
|
||||||
assertThat(request.getParameter("username")).isEqualTo("user");
|
assertThat(request.getParameter("username")).isEqualTo("user");
|
||||||
assertThat(request.getParameter("password")).isEqualTo("password");
|
assertThat(request.getParameter("password")).isEqualTo("password");
|
||||||
|
@ -64,8 +44,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests {
|
||||||
assertThat(request.getParameter(token.getParameterName())).isEqualTo(
|
assertThat(request.getParameter(token.getParameterName())).isEqualTo(
|
||||||
token.getToken());
|
token.getToken());
|
||||||
assertThat(request.getRequestURI()).isEqualTo("/login");
|
assertThat(request.getRequestURI()).isEqualTo("/login");
|
||||||
verify(repository).saveToken(eq(token), any(HttpServletRequest.class),
|
assertThat(request.getParameter("_csrf")).isNotNull();
|
||||||
any(HttpServletResponse.class));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -73,20 +52,13 @@ public class SecurityMockMvcRequestBuildersFormLoginTests {
|
||||||
MockHttpServletRequest request = formLogin("/login").user("username", "admin")
|
MockHttpServletRequest request = formLogin("/login").user("username", "admin")
|
||||||
.password("password", "secret").buildRequest(servletContext);
|
.password("password", "secret").buildRequest(servletContext);
|
||||||
|
|
||||||
|
CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME);
|
||||||
|
|
||||||
assertThat(request.getParameter("username")).isEqualTo("admin");
|
assertThat(request.getParameter("username")).isEqualTo("admin");
|
||||||
assertThat(request.getParameter("password")).isEqualTo("secret");
|
assertThat(request.getParameter("password")).isEqualTo("secret");
|
||||||
assertThat(request.getMethod()).isEqualTo("POST");
|
assertThat(request.getMethod()).isEqualTo("POST");
|
||||||
assertThat(request.getParameter(token.getParameterName())).isEqualTo(
|
assertThat(request.getParameter(token.getParameterName())).isEqualTo(
|
||||||
token.getToken());
|
token.getToken());
|
||||||
assertThat(request.getRequestURI()).isEqualTo("/login");
|
assertThat(request.getRequestURI()).isEqualTo("/login");
|
||||||
verify(repository).saveToken(eq(token), any(HttpServletRequest.class),
|
|
||||||
any(HttpServletResponse.class));
|
|
||||||
}
|
|
||||||
|
|
||||||
private void mockWebTestUtils() throws Exception {
|
|
||||||
spy(WebTestUtils.class);
|
|
||||||
doReturn(repository).when(WebTestUtils.class, "getCsrfTokenRepository",
|
|
||||||
any(HttpServletRequest.class));
|
|
||||||
when(repository.generateToken(any(HttpServletRequest.class))).thenReturn(token);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,75 +15,47 @@
|
||||||
*/
|
*/
|
||||||
package org.springframework.security.test.web.servlet.request;
|
package org.springframework.security.test.web.servlet.request;
|
||||||
|
|
||||||
import static org.mockito.Matchers.any;
|
|
||||||
import static org.mockito.Matchers.eq;
|
|
||||||
import static org.mockito.Mockito.verify;
|
|
||||||
import static org.fest.assertions.Assertions.assertThat;
|
import static org.fest.assertions.Assertions.assertThat;
|
||||||
import static org.powermock.api.mockito.PowerMockito.spy;
|
|
||||||
import static org.powermock.api.mockito.PowerMockito.when;
|
|
||||||
import static org.powermock.api.mockito.PowerMockito.doReturn;
|
|
||||||
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.logout;
|
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.logout;
|
||||||
|
|
||||||
import javax.servlet.http.HttpServletRequest;
|
|
||||||
import javax.servlet.http.HttpServletResponse;
|
|
||||||
|
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.junit.runner.RunWith;
|
|
||||||
import org.mockito.Mock;
|
|
||||||
import org.powermock.core.classloader.annotations.PrepareForTest;
|
|
||||||
import org.powermock.modules.junit4.PowerMockRunner;
|
|
||||||
import org.springframework.mock.web.MockHttpServletRequest;
|
import org.springframework.mock.web.MockHttpServletRequest;
|
||||||
import org.springframework.mock.web.MockServletContext;
|
import org.springframework.mock.web.MockServletContext;
|
||||||
import org.springframework.security.test.web.support.WebTestUtils;
|
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.CsrfRequestPostProcessor;
|
||||||
import org.springframework.security.web.csrf.CsrfTokenRepository;
|
import org.springframework.security.web.csrf.CsrfToken;
|
||||||
import org.springframework.security.web.csrf.DefaultCsrfToken;
|
|
||||||
|
|
||||||
@RunWith(PowerMockRunner.class)
|
|
||||||
@PrepareForTest({ WebTestUtils.class, SecurityMockMvcRequestBuildersFormLogoutTests.class })
|
|
||||||
public class SecurityMockMvcRequestBuildersFormLogoutTests {
|
public class SecurityMockMvcRequestBuildersFormLogoutTests {
|
||||||
@Mock
|
|
||||||
private CsrfTokenRepository repository;
|
|
||||||
private DefaultCsrfToken token;
|
|
||||||
private MockServletContext servletContext;
|
private MockServletContext servletContext;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setup() {
|
public void setup() {
|
||||||
token = new DefaultCsrfToken("header", "param", "token");
|
|
||||||
servletContext = new MockServletContext();
|
servletContext = new MockServletContext();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void defaults() throws Exception {
|
public void defaults() throws Exception {
|
||||||
mockWebTestUtils();
|
|
||||||
MockHttpServletRequest request = logout().buildRequest(servletContext);
|
MockHttpServletRequest request = logout().buildRequest(servletContext);
|
||||||
|
|
||||||
|
CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME);
|
||||||
|
|
||||||
assertThat(request.getMethod()).isEqualTo("POST");
|
assertThat(request.getMethod()).isEqualTo("POST");
|
||||||
assertThat(request.getParameter(token.getParameterName())).isEqualTo(
|
assertThat(request.getParameter(token.getParameterName())).isEqualTo(
|
||||||
token.getToken());
|
token.getToken());
|
||||||
assertThat(request.getRequestURI()).isEqualTo("/logout");
|
assertThat(request.getRequestURI()).isEqualTo("/logout");
|
||||||
verify(repository).saveToken(eq(token), any(HttpServletRequest.class),
|
|
||||||
any(HttpServletResponse.class));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void custom() throws Exception {
|
public void custom() throws Exception {
|
||||||
mockWebTestUtils();
|
|
||||||
MockHttpServletRequest request = logout("/admin/logout").buildRequest(
|
MockHttpServletRequest request = logout("/admin/logout").buildRequest(
|
||||||
servletContext);
|
servletContext);
|
||||||
|
|
||||||
|
CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME);
|
||||||
|
|
||||||
assertThat(request.getMethod()).isEqualTo("POST");
|
assertThat(request.getMethod()).isEqualTo("POST");
|
||||||
assertThat(request.getParameter(token.getParameterName())).isEqualTo(
|
assertThat(request.getParameter(token.getParameterName())).isEqualTo(
|
||||||
token.getToken());
|
token.getToken());
|
||||||
assertThat(request.getRequestURI()).isEqualTo("/admin/logout");
|
assertThat(request.getRequestURI()).isEqualTo("/admin/logout");
|
||||||
verify(repository).saveToken(eq(token), any(HttpServletRequest.class),
|
|
||||||
any(HttpServletResponse.class));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void mockWebTestUtils() throws Exception {
|
|
||||||
spy(WebTestUtils.class);
|
|
||||||
doReturn(repository).when(WebTestUtils.class, "getCsrfTokenRepository",
|
|
||||||
any(HttpServletRequest.class));
|
|
||||||
when(repository.generateToken(any(HttpServletRequest.class))).thenReturn(token);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,12 +21,22 @@ import static org.springframework.security.test.web.servlet.setup.SecurityMockMv
|
||||||
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
|
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
|
||||||
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
|
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
|
||||||
|
import javax.servlet.FilterChain;
|
||||||
|
import javax.servlet.ServletException;
|
||||||
|
import javax.servlet.http.HttpServletRequest;
|
||||||
|
import javax.servlet.http.HttpServletRequestWrapper;
|
||||||
|
import javax.servlet.http.HttpServletResponse;
|
||||||
|
import javax.servlet.http.HttpSession;
|
||||||
|
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.junit.runner.RunWith;
|
import org.junit.runner.RunWith;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.context.annotation.Bean;
|
import org.springframework.context.annotation.Bean;
|
||||||
import org.springframework.mock.web.MockHttpServletRequest;
|
import org.springframework.mock.web.MockHttpServletRequest;
|
||||||
|
import org.springframework.mock.web.MockHttpSession;
|
||||||
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
|
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
|
||||||
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
|
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
|
||||||
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
|
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
|
||||||
|
@ -40,6 +50,7 @@ import org.springframework.test.web.servlet.setup.MockMvcBuilders;
|
||||||
import org.springframework.web.bind.annotation.RequestMapping;
|
import org.springframework.web.bind.annotation.RequestMapping;
|
||||||
import org.springframework.web.bind.annotation.RestController;
|
import org.springframework.web.bind.annotation.RestController;
|
||||||
import org.springframework.web.context.WebApplicationContext;
|
import org.springframework.web.context.WebApplicationContext;
|
||||||
|
import org.springframework.web.filter.OncePerRequestFilter;
|
||||||
|
|
||||||
@RunWith(SpringJUnit4ClassRunner.class)
|
@RunWith(SpringJUnit4ClassRunner.class)
|
||||||
@ContextConfiguration
|
@ContextConfiguration
|
||||||
|
@ -86,6 +97,20 @@ public class SecurityMockMvcRequestPostProcessorsCsrfTests {
|
||||||
.andExpect(csrfAsHeader());
|
.andExpect(csrfAsHeader());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SEC-3097
|
||||||
|
@Test
|
||||||
|
public void csrfWithWrappedRequest() throws Exception {
|
||||||
|
mockMvc = MockMvcBuilders
|
||||||
|
.webAppContextSetup(wac)
|
||||||
|
.addFilter(new SessionRepositoryFilter())
|
||||||
|
.apply(springSecurity())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
mockMvc.perform(post("/").with(csrf()))
|
||||||
|
.andExpect(status().is2xxSuccessful())
|
||||||
|
.andExpect(csrfAsParam());
|
||||||
|
}
|
||||||
|
|
||||||
public static ResultMatcher csrfAsParam() {
|
public static ResultMatcher csrfAsParam() {
|
||||||
return new CsrfParamResultMatcher();
|
return new CsrfParamResultMatcher();
|
||||||
}
|
}
|
||||||
|
@ -112,6 +137,33 @@ public class SecurityMockMvcRequestPostProcessorsCsrfTests {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static class SessionRepositoryFilter extends OncePerRequestFilter {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
|
||||||
|
throws ServletException, IOException {
|
||||||
|
filterChain.doFilter(new SessionRequestWrapper(request) , response);
|
||||||
|
}
|
||||||
|
|
||||||
|
static class SessionRequestWrapper extends HttpServletRequestWrapper {
|
||||||
|
HttpSession session = new MockHttpSession();
|
||||||
|
|
||||||
|
public SessionRequestWrapper(HttpServletRequest request) {
|
||||||
|
super(request);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public HttpSession getSession(boolean create) {
|
||||||
|
return session;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public HttpSession getSession() {
|
||||||
|
return session;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@EnableWebSecurity
|
@EnableWebSecurity
|
||||||
static class Config extends WebSecurityConfigurerAdapter {
|
static class Config extends WebSecurityConfigurerAdapter {
|
||||||
@Override
|
@Override
|
||||||
|
|
Loading…
Reference in New Issue