SEC-3097: Change CsrfRequestPostProcessor to use TestCsrfTokenRepository
This ensures that when using a wrapped HttpServletRequest (i.e. Spring Session) that the CSRF token test support still works.
This commit is contained in:
parent
ea94706319
commit
81e2778106
|
@ -316,6 +316,10 @@ public final class SecurityMockMvcRequestPostProcessors {
|
|||
public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) {
|
||||
|
||||
CsrfTokenRepository repository = WebTestUtils.getCsrfTokenRepository(request);
|
||||
if(!(repository instanceof TestCsrfTokenRepository)) {
|
||||
repository = new TestCsrfTokenRepository(repository);
|
||||
WebTestUtils.setCsrfTokenRepository(request, repository);
|
||||
}
|
||||
CsrfToken token = repository.generateToken(request);
|
||||
repository.saveToken(token, request, new MockHttpServletResponse());
|
||||
String tokenValue = useInvalidToken ? "invalid" + token.getToken() : token
|
||||
|
@ -352,6 +356,36 @@ public final class SecurityMockMvcRequestPostProcessors {
|
|||
|
||||
private CsrfRequestPostProcessor() {
|
||||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Used to wrap the CsrfTokenRepository to provide support for testing
|
||||
* when the request is wrapped (i.e. Spring Session is in use).
|
||||
*/
|
||||
static class TestCsrfTokenRepository implements
|
||||
CsrfTokenRepository {
|
||||
final static String ATTR_NAME = TestCsrfTokenRepository.class
|
||||
.getName().concat(".TOKEN");
|
||||
|
||||
private final CsrfTokenRepository delegate;
|
||||
|
||||
private TestCsrfTokenRepository(CsrfTokenRepository delegate) {
|
||||
this.delegate = delegate;
|
||||
}
|
||||
|
||||
public CsrfToken generateToken(HttpServletRequest request) {
|
||||
return delegate.generateToken(request);
|
||||
}
|
||||
|
||||
public void saveToken(CsrfToken token, HttpServletRequest request, HttpServletResponse response) {
|
||||
request.setAttribute(ATTR_NAME, token);
|
||||
}
|
||||
|
||||
public CsrfToken loadToken(HttpServletRequest request) {
|
||||
return (CsrfToken) request.getAttribute(ATTR_NAME);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public static class DigestRequestPostProcessor implements RequestPostProcessor {
|
||||
|
|
|
@ -97,6 +97,22 @@ public abstract class WebTestUtils {
|
|||
"tokenRepository");
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the {@link CsrfTokenRepository} for the specified
|
||||
* {@link HttpServletRequest}.
|
||||
*
|
||||
* @param request the {@link HttpServletRequest} to obtain the
|
||||
* {@link CsrfTokenRepository}
|
||||
* @param repository the {@link CsrfTokenRepository} to set
|
||||
*/
|
||||
public static void setCsrfTokenRepository(HttpServletRequest request,
|
||||
CsrfTokenRepository repository) {
|
||||
CsrfFilter filter = findFilter(request, CsrfFilter.class);
|
||||
if (filter != null) {
|
||||
ReflectionTestUtils.setField(filter, "tokenRepository", repository);
|
||||
}
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static <T extends Filter> T findFilter(HttpServletRequest request,
|
||||
Class<T> filterClass) {
|
||||
|
|
|
@ -15,48 +15,28 @@
|
|||
*/
|
||||
package org.springframework.security.test.web.servlet.request;
|
||||
|
||||
import static org.mockito.Matchers.any;
|
||||
import static org.mockito.Matchers.eq;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.fest.assertions.Assertions.assertThat;
|
||||
import static org.powermock.api.mockito.PowerMockito.spy;
|
||||
import static org.powermock.api.mockito.PowerMockito.when;
|
||||
import static org.powermock.api.mockito.PowerMockito.doReturn;
|
||||
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.mockito.Mock;
|
||||
import org.powermock.core.classloader.annotations.PrepareForTest;
|
||||
import org.powermock.modules.junit4.PowerMockRunner;
|
||||
import org.springframework.mock.web.MockHttpServletRequest;
|
||||
import org.springframework.mock.web.MockServletContext;
|
||||
import org.springframework.security.test.web.support.WebTestUtils;
|
||||
import org.springframework.security.web.csrf.CsrfTokenRepository;
|
||||
import org.springframework.security.web.csrf.DefaultCsrfToken;
|
||||
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.CsrfRequestPostProcessor;
|
||||
import org.springframework.security.web.csrf.CsrfToken;
|
||||
|
||||
@RunWith(PowerMockRunner.class)
|
||||
@PrepareForTest({ WebTestUtils.class, SecurityMockMvcRequestBuildersFormLoginTests.class })
|
||||
public class SecurityMockMvcRequestBuildersFormLoginTests {
|
||||
@Mock
|
||||
private CsrfTokenRepository repository;
|
||||
private DefaultCsrfToken token;
|
||||
private MockServletContext servletContext;
|
||||
|
||||
@Before
|
||||
public void setup() throws Exception {
|
||||
token = new DefaultCsrfToken("header", "param", "token");
|
||||
servletContext = new MockServletContext();
|
||||
mockWebTestUtils();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void defaults() throws Exception {
|
||||
MockHttpServletRequest request = formLogin().buildRequest(servletContext);
|
||||
CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME);
|
||||
|
||||
assertThat(request.getParameter("username")).isEqualTo("user");
|
||||
assertThat(request.getParameter("password")).isEqualTo("password");
|
||||
|
@ -64,8 +44,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests {
|
|||
assertThat(request.getParameter(token.getParameterName())).isEqualTo(
|
||||
token.getToken());
|
||||
assertThat(request.getRequestURI()).isEqualTo("/login");
|
||||
verify(repository).saveToken(eq(token), any(HttpServletRequest.class),
|
||||
any(HttpServletResponse.class));
|
||||
assertThat(request.getParameter("_csrf")).isNotNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -73,20 +52,13 @@ public class SecurityMockMvcRequestBuildersFormLoginTests {
|
|||
MockHttpServletRequest request = formLogin("/login").user("username", "admin")
|
||||
.password("password", "secret").buildRequest(servletContext);
|
||||
|
||||
CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME);
|
||||
|
||||
assertThat(request.getParameter("username")).isEqualTo("admin");
|
||||
assertThat(request.getParameter("password")).isEqualTo("secret");
|
||||
assertThat(request.getMethod()).isEqualTo("POST");
|
||||
assertThat(request.getParameter(token.getParameterName())).isEqualTo(
|
||||
token.getToken());
|
||||
assertThat(request.getRequestURI()).isEqualTo("/login");
|
||||
verify(repository).saveToken(eq(token), any(HttpServletRequest.class),
|
||||
any(HttpServletResponse.class));
|
||||
}
|
||||
|
||||
private void mockWebTestUtils() throws Exception {
|
||||
spy(WebTestUtils.class);
|
||||
doReturn(repository).when(WebTestUtils.class, "getCsrfTokenRepository",
|
||||
any(HttpServletRequest.class));
|
||||
when(repository.generateToken(any(HttpServletRequest.class))).thenReturn(token);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,75 +15,47 @@
|
|||
*/
|
||||
package org.springframework.security.test.web.servlet.request;
|
||||
|
||||
import static org.mockito.Matchers.any;
|
||||
import static org.mockito.Matchers.eq;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.fest.assertions.Assertions.assertThat;
|
||||
import static org.powermock.api.mockito.PowerMockito.spy;
|
||||
import static org.powermock.api.mockito.PowerMockito.when;
|
||||
import static org.powermock.api.mockito.PowerMockito.doReturn;
|
||||
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.logout;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.mockito.Mock;
|
||||
import org.powermock.core.classloader.annotations.PrepareForTest;
|
||||
import org.powermock.modules.junit4.PowerMockRunner;
|
||||
import org.springframework.mock.web.MockHttpServletRequest;
|
||||
import org.springframework.mock.web.MockServletContext;
|
||||
import org.springframework.security.test.web.support.WebTestUtils;
|
||||
import org.springframework.security.web.csrf.CsrfTokenRepository;
|
||||
import org.springframework.security.web.csrf.DefaultCsrfToken;
|
||||
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.CsrfRequestPostProcessor;
|
||||
import org.springframework.security.web.csrf.CsrfToken;
|
||||
|
||||
@RunWith(PowerMockRunner.class)
|
||||
@PrepareForTest({ WebTestUtils.class, SecurityMockMvcRequestBuildersFormLogoutTests.class })
|
||||
public class SecurityMockMvcRequestBuildersFormLogoutTests {
|
||||
@Mock
|
||||
private CsrfTokenRepository repository;
|
||||
private DefaultCsrfToken token;
|
||||
private MockServletContext servletContext;
|
||||
|
||||
@Before
|
||||
public void setup() {
|
||||
token = new DefaultCsrfToken("header", "param", "token");
|
||||
servletContext = new MockServletContext();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void defaults() throws Exception {
|
||||
mockWebTestUtils();
|
||||
MockHttpServletRequest request = logout().buildRequest(servletContext);
|
||||
|
||||
CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME);
|
||||
|
||||
assertThat(request.getMethod()).isEqualTo("POST");
|
||||
assertThat(request.getParameter(token.getParameterName())).isEqualTo(
|
||||
token.getToken());
|
||||
assertThat(request.getRequestURI()).isEqualTo("/logout");
|
||||
verify(repository).saveToken(eq(token), any(HttpServletRequest.class),
|
||||
any(HttpServletResponse.class));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void custom() throws Exception {
|
||||
mockWebTestUtils();
|
||||
MockHttpServletRequest request = logout("/admin/logout").buildRequest(
|
||||
servletContext);
|
||||
|
||||
CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME);
|
||||
|
||||
assertThat(request.getMethod()).isEqualTo("POST");
|
||||
assertThat(request.getParameter(token.getParameterName())).isEqualTo(
|
||||
token.getToken());
|
||||
assertThat(request.getRequestURI()).isEqualTo("/admin/logout");
|
||||
verify(repository).saveToken(eq(token), any(HttpServletRequest.class),
|
||||
any(HttpServletResponse.class));
|
||||
}
|
||||
|
||||
private void mockWebTestUtils() throws Exception {
|
||||
spy(WebTestUtils.class);
|
||||
doReturn(repository).when(WebTestUtils.class, "getCsrfTokenRepository",
|
||||
any(HttpServletRequest.class));
|
||||
when(repository.generateToken(any(HttpServletRequest.class))).thenReturn(token);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,12 +21,22 @@ import static org.springframework.security.test.web.servlet.setup.SecurityMockMv
|
|||
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
|
||||
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
import javax.servlet.FilterChain;
|
||||
import javax.servlet.ServletException;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletRequestWrapper;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import javax.servlet.http.HttpSession;
|
||||
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.mock.web.MockHttpServletRequest;
|
||||
import org.springframework.mock.web.MockHttpSession;
|
||||
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
|
||||
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
|
||||
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
|
||||
|
@ -40,6 +50,7 @@ import org.springframework.test.web.servlet.setup.MockMvcBuilders;
|
|||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
import org.springframework.web.context.WebApplicationContext;
|
||||
import org.springframework.web.filter.OncePerRequestFilter;
|
||||
|
||||
@RunWith(SpringJUnit4ClassRunner.class)
|
||||
@ContextConfiguration
|
||||
|
@ -86,6 +97,20 @@ public class SecurityMockMvcRequestPostProcessorsCsrfTests {
|
|||
.andExpect(csrfAsHeader());
|
||||
}
|
||||
|
||||
// SEC-3097
|
||||
@Test
|
||||
public void csrfWithWrappedRequest() throws Exception {
|
||||
mockMvc = MockMvcBuilders
|
||||
.webAppContextSetup(wac)
|
||||
.addFilter(new SessionRepositoryFilter())
|
||||
.apply(springSecurity())
|
||||
.build();
|
||||
|
||||
mockMvc.perform(post("/").with(csrf()))
|
||||
.andExpect(status().is2xxSuccessful())
|
||||
.andExpect(csrfAsParam());
|
||||
}
|
||||
|
||||
public static ResultMatcher csrfAsParam() {
|
||||
return new CsrfParamResultMatcher();
|
||||
}
|
||||
|
@ -112,6 +137,33 @@ public class SecurityMockMvcRequestPostProcessorsCsrfTests {
|
|||
}
|
||||
}
|
||||
|
||||
static class SessionRepositoryFilter extends OncePerRequestFilter {
|
||||
|
||||
@Override
|
||||
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
|
||||
throws ServletException, IOException {
|
||||
filterChain.doFilter(new SessionRequestWrapper(request) , response);
|
||||
}
|
||||
|
||||
static class SessionRequestWrapper extends HttpServletRequestWrapper {
|
||||
HttpSession session = new MockHttpSession();
|
||||
|
||||
public SessionRequestWrapper(HttpServletRequest request) {
|
||||
super(request);
|
||||
}
|
||||
|
||||
@Override
|
||||
public HttpSession getSession(boolean create) {
|
||||
return session;
|
||||
}
|
||||
|
||||
@Override
|
||||
public HttpSession getSession() {
|
||||
return session;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@EnableWebSecurity
|
||||
static class Config extends WebSecurityConfigurerAdapter {
|
||||
@Override
|
||||
|
|
Loading…
Reference in New Issue