Fix csrf() when used then not used

Previously if csrf() was used and subsequently not used, the
TestCsrfTokenRepository was still used. This makes it difficult to test
the actual CsrfTokenRepository implementation.

Now the TestCsrfTokenRepository is only used if explicitly enabled.

Fixes gh-4016
This commit is contained in:
Rob Winch 2016-08-09 15:00:06 -05:00 committed by Joe Grandja
parent 519c15efb3
commit 050198e51b
4 changed files with 52 additions and 8 deletions

View File

@ -318,13 +318,13 @@ public final class SecurityMockMvcRequestPostProcessors {
*/ */
@Override @Override
public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) {
CsrfTokenRepository repository = WebTestUtils.getCsrfTokenRepository(request); CsrfTokenRepository repository = WebTestUtils.getCsrfTokenRepository(request);
if (!(repository instanceof TestCsrfTokenRepository)) { if (!(repository instanceof TestCsrfTokenRepository)) {
repository = new TestCsrfTokenRepository( repository = new TestCsrfTokenRepository(
new HttpSessionCsrfTokenRepository()); new HttpSessionCsrfTokenRepository());
WebTestUtils.setCsrfTokenRepository(request, repository); WebTestUtils.setCsrfTokenRepository(request, repository);
} }
TestCsrfTokenRepository.enable(request);
CsrfToken token = repository.generateToken(request); CsrfToken token = repository.generateToken(request);
repository.saveToken(token, request, new MockHttpServletResponse()); repository.saveToken(token, request, new MockHttpServletResponse());
String tokenValue = this.useInvalidToken ? "invalid" + token.getToken() String tokenValue = this.useInvalidToken ? "invalid" + token.getToken()
@ -367,9 +367,12 @@ public final class SecurityMockMvcRequestPostProcessors {
* request is wrapped (i.e. Spring Session is in use). * request is wrapped (i.e. Spring Session is in use).
*/ */
static class TestCsrfTokenRepository implements CsrfTokenRepository { static class TestCsrfTokenRepository implements CsrfTokenRepository {
final static String ATTR_NAME = TestCsrfTokenRepository.class.getName() final static String TOKEN_ATTR_NAME = TestCsrfTokenRepository.class.getName()
.concat(".TOKEN"); .concat(".TOKEN");
final static String ENABLED_ATTR_NAME = TestCsrfTokenRepository.class
.getName().concat(".ENABLED");
private final CsrfTokenRepository delegate; private final CsrfTokenRepository delegate;
private TestCsrfTokenRepository(CsrfTokenRepository delegate) { private TestCsrfTokenRepository(CsrfTokenRepository delegate) {
@ -384,12 +387,30 @@ public final class SecurityMockMvcRequestPostProcessors {
@Override @Override
public void saveToken(CsrfToken token, HttpServletRequest request, public void saveToken(CsrfToken token, HttpServletRequest request,
HttpServletResponse response) { HttpServletResponse response) {
request.setAttribute(ATTR_NAME, token); if (isEnabled(request)) {
request.setAttribute(TOKEN_ATTR_NAME, token);
}
else {
this.delegate.saveToken(token, request, response);
}
} }
@Override @Override
public CsrfToken loadToken(HttpServletRequest request) { public CsrfToken loadToken(HttpServletRequest request) {
return (CsrfToken) request.getAttribute(ATTR_NAME); if (isEnabled(request)) {
return (CsrfToken) request.getAttribute(TOKEN_ATTR_NAME);
}
else {
return this.delegate.loadToken(request);
}
}
public static void enable(HttpServletRequest request) {
request.setAttribute(ENABLED_ATTR_NAME, Boolean.TRUE);
}
public boolean isEnabled(HttpServletRequest request) {
return Boolean.TRUE.equals(request.getAttribute(ENABLED_ATTR_NAME));
} }
} }
} }

View File

@ -39,7 +39,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests {
public void defaults() throws Exception { public void defaults() throws Exception {
MockHttpServletRequest request = formLogin().buildRequest(this.servletContext); MockHttpServletRequest request = formLogin().buildRequest(this.servletContext);
CsrfToken token = (CsrfToken) request CsrfToken token = (CsrfToken) request
.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME); .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
assertThat(request.getParameter("username")).isEqualTo("user"); assertThat(request.getParameter("username")).isEqualTo("user");
assertThat(request.getParameter("password")).isEqualTo("password"); assertThat(request.getParameter("password")).isEqualTo("password");
@ -56,7 +56,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests {
.password("password", "secret").buildRequest(this.servletContext); .password("password", "secret").buildRequest(this.servletContext);
CsrfToken token = (CsrfToken) request CsrfToken token = (CsrfToken) request
.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME); .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
assertThat(request.getParameter("username")).isEqualTo("admin"); assertThat(request.getParameter("username")).isEqualTo("admin");
assertThat(request.getParameter("password")).isEqualTo("secret"); assertThat(request.getParameter("password")).isEqualTo("secret");

View File

@ -37,7 +37,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests {
public void defaults() throws Exception { public void defaults() throws Exception {
MockHttpServletRequest request = logout().buildRequest(servletContext); MockHttpServletRequest request = logout().buildRequest(servletContext);
CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME); CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getMethod()).isEqualTo("POST");
assertThat(request.getParameter(token.getParameterName())).isEqualTo( assertThat(request.getParameter(token.getParameterName())).isEqualTo(
@ -50,7 +50,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests {
MockHttpServletRequest request = logout("/admin/logout").buildRequest( MockHttpServletRequest request = logout("/admin/logout").buildRequest(
servletContext); servletContext);
CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME); CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getMethod()).isEqualTo("POST");
assertThat(request.getParameter(token.getParameterName())).isEqualTo( assertThat(request.getParameter(token.getParameterName())).isEqualTo(

View File

@ -31,18 +31,22 @@ import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.mock.web.MockHttpSession; import org.springframework.mock.web.MockHttpSession;
import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessorsCsrfTests.Config.TheController; import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessorsCsrfTests.Config.TheController;
import org.springframework.security.web.FilterChainProxy; import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
import org.springframework.test.context.web.WebAppConfiguration; import org.springframework.test.context.web.WebAppConfiguration;
import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult; import org.springframework.test.web.servlet.MvcResult;
import org.springframework.test.web.servlet.ResultMatcher; import org.springframework.test.web.servlet.ResultMatcher;
import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder;
import org.springframework.test.web.servlet.setup.MockMvcBuilders; import org.springframework.test.web.servlet.setup.MockMvcBuilders;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
@ -143,6 +147,25 @@ public class SecurityMockMvcRequestPostProcessorsCsrfTests {
// @formatter:on // @formatter:on
} }
// gh-4016
@Test
public void csrfWhenUsedThenDoesNotImpactOriginalRepository() throws Exception {
// @formatter:off
this.mockMvc.perform(post("/").with(csrf()));
MockHttpServletRequest request = new MockHttpServletRequest();
HttpSessionCsrfTokenRepository repo = new HttpSessionCsrfTokenRepository();
CsrfToken token = repo.generateToken(request);
repo.saveToken(token, request, new MockHttpServletResponse());
MockHttpServletRequestBuilder requestWithCsrf = post("/")
.param(token.getParameterName(), token.getToken())
.session((MockHttpSession)request.getSession());
this.mockMvc.perform(requestWithCsrf)
.andExpect(status().isOk());
// @formatter:on
}
public static ResultMatcher csrfAsParam() { public static ResultMatcher csrfAsParam() {
return new CsrfParamResultMatcher(); return new CsrfParamResultMatcher();
} }