Add test support for Xor CSRF tokens

Issue gh-4001
This commit is contained in:
Steve Riesenberg 2022-10-12 11:11:52 -05:00
parent 8bd25f90e4
commit 440748ec65
No known key found for this signature in database
GPG Key ID: 5F311AB48A55D521
5 changed files with 61 additions and 36 deletions

View File

@ -41,6 +41,7 @@ import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.access.AccessDeniedHandler; import org.springframework.security.web.access.AccessDeniedHandler;
import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfFilter;
import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.CsrfTokenRepository;
import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.stereotype.Controller; import org.springframework.stereotype.Controller;
import org.springframework.test.context.junit.jupiter.SpringExtension; import org.springframework.test.context.junit.jupiter.SpringExtension;
@ -301,24 +302,7 @@ public class CsrfConfigTests {
} }
@Test @Test
public void postWhenUsingCsrfAndXorCsrfTokenRequestProcessorThenOk() throws Exception { public void postWhenUsingCsrfAndXorCsrfTokenRequestAttributeHandlerThenOk() throws Exception {
this.spring.configLocations(this.xml("WithXorCsrfTokenRequestAttributeHandler"), this.xml("shared-controllers"))
.autowire();
// @formatter:off
MvcResult mvcResult = this.mvc.perform(get("/ok"))
.andExpect(status().isOk())
.andReturn();
MockHttpSession session = (MockHttpSession) mvcResult.getRequest().getSession();
CsrfToken csrfToken = (CsrfToken) mvcResult.getRequest().getAttribute("_csrf");
MockHttpServletRequestBuilder ok = post("/ok")
.header(csrfToken.getHeaderName(), csrfToken.getToken())
.session(session);
this.mvc.perform(ok).andExpect(status().isOk());
// @formatter:on
}
@Test
public void postWhenUsingCsrfAndXorCsrfTokenRequestProcessorWithRawTokenThenForbidden() throws Exception {
this.spring.configLocations(this.xml("WithXorCsrfTokenRequestAttributeHandler"), this.xml("shared-controllers")) this.spring.configLocations(this.xml("WithXorCsrfTokenRequestAttributeHandler"), this.xml("shared-controllers"))
.autowire(); .autowire();
// @formatter:off // @formatter:off
@ -329,6 +313,25 @@ public class CsrfConfigTests {
MockHttpServletRequestBuilder ok = post("/ok") MockHttpServletRequestBuilder ok = post("/ok")
.with(csrf()) .with(csrf())
.session(session); .session(session);
this.mvc.perform(ok).andExpect(status().isOk());
// @formatter:on
}
@Test
public void postWhenUsingCsrfAndXorCsrfTokenRequestAttributeHandlerWithRawTokenThenForbidden() throws Exception {
this.spring.configLocations(this.xml("WithXorCsrfTokenRequestAttributeHandler"), this.xml("shared-controllers"))
.autowire();
// @formatter:off
MvcResult mvcResult = this.mvc.perform(get("/csrf"))
.andExpect(status().isOk())
.andReturn();
MockHttpServletRequest request = mvcResult.getRequest();
MockHttpSession session = (MockHttpSession) request.getSession();
CsrfTokenRepository repository = WebTestUtils.getCsrfTokenRepository(request);
CsrfToken csrfToken = repository.loadToken(request);
MockHttpServletRequestBuilder ok = post("/ok")
.header(csrfToken.getHeaderName(), csrfToken.getToken())
.session(session);
this.mvc.perform(ok).andExpect(status().isForbidden()); this.mvc.perform(ok).andExpect(status().isForbidden());
// @formatter:on // @formatter:on
} }
@ -594,7 +597,7 @@ public class CsrfConfigTests {
@Override @Override
public void match(MvcResult result) throws Exception { public void match(MvcResult result) throws Exception {
MockHttpServletRequest request = result.getRequest(); MockHttpServletRequest request = result.getRequest();
CsrfToken token = WebTestUtils.getCsrfTokenRepository(request).loadToken(request); CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
assertThat(token).isNotNull(); assertThat(token).isNotNull();
assertThat(token.getToken()).isEqualTo(this.token.apply(result)); assertThat(token.getToken()).isEqualTo(this.token.apply(result));
} }

View File

@ -95,6 +95,8 @@ import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfFilter;
import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.CsrfTokenRepository; import org.springframework.security.web.csrf.CsrfTokenRepository;
import org.springframework.security.web.csrf.CsrfTokenRequestHandler;
import org.springframework.security.web.csrf.DeferredCsrfToken;
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
import org.springframework.test.util.ReflectionTestUtils; import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MockMvc;
@ -499,6 +501,10 @@ public final class SecurityMockMvcRequestPostProcessors {
*/ */
public static final class CsrfRequestPostProcessor implements RequestPostProcessor { public static final class CsrfRequestPostProcessor implements RequestPostProcessor {
private static final byte[] INVALID_TOKEN_BYTES = new byte[] { 1, 1, 1, 96, 99, 98 };
private static final String INVALID_TOKEN_VALUE = Base64.getEncoder().encodeToString(INVALID_TOKEN_BYTES);
private boolean asHeader; private boolean asHeader;
private boolean useInvalidToken; private boolean useInvalidToken;
@ -509,14 +515,17 @@ public final class SecurityMockMvcRequestPostProcessors {
@Override @Override
public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) {
CsrfTokenRepository repository = WebTestUtils.getCsrfTokenRepository(request); CsrfTokenRepository repository = WebTestUtils.getCsrfTokenRepository(request);
CsrfTokenRequestHandler handler = WebTestUtils.getCsrfTokenRequestHandler(request);
if (!(repository instanceof TestCsrfTokenRepository)) { if (!(repository instanceof TestCsrfTokenRepository)) {
repository = new TestCsrfTokenRepository(new HttpSessionCsrfTokenRepository()); repository = new TestCsrfTokenRepository(new HttpSessionCsrfTokenRepository());
WebTestUtils.setCsrfTokenRepository(request, repository); WebTestUtils.setCsrfTokenRepository(request, repository);
} }
TestCsrfTokenRepository.enable(request); TestCsrfTokenRepository.enable(request);
CsrfToken token = repository.generateToken(request); MockHttpServletResponse response = new MockHttpServletResponse();
repository.saveToken(token, request, new MockHttpServletResponse()); DeferredCsrfToken deferredCsrfToken = repository.loadDeferredToken(request, response);
String tokenValue = this.useInvalidToken ? "invalid" + token.getToken() : token.getToken(); handler.handle(request, response, deferredCsrfToken::get);
CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
String tokenValue = this.useInvalidToken ? INVALID_TOKEN_VALUE : token.getToken();
if (this.asHeader) { if (this.asHeader) {
request.addHeader(token.getHeaderName(), tokenValue); request.addHeader(token.getHeaderName(), tokenValue);
} }

View File

@ -31,7 +31,9 @@ import org.springframework.security.web.context.SecurityContextPersistenceFilter
import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfFilter;
import org.springframework.security.web.csrf.CsrfTokenRepository; import org.springframework.security.web.csrf.CsrfTokenRepository;
import org.springframework.security.web.csrf.CsrfTokenRequestHandler;
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
import org.springframework.security.web.csrf.XorCsrfTokenRequestAttributeHandler;
import org.springframework.test.util.ReflectionTestUtils; import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.web.context.WebApplicationContext; import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.context.support.WebApplicationContextUtils; import org.springframework.web.context.support.WebApplicationContextUtils;
@ -48,6 +50,8 @@ public abstract class WebTestUtils {
private static final CsrfTokenRepository DEFAULT_TOKEN_REPO = new HttpSessionCsrfTokenRepository(); private static final CsrfTokenRepository DEFAULT_TOKEN_REPO = new HttpSessionCsrfTokenRepository();
private static final CsrfTokenRequestHandler DEFAULT_CSRF_HANDLER = new XorCsrfTokenRequestAttributeHandler();
private WebTestUtils() { private WebTestUtils() {
} }
@ -107,6 +111,23 @@ public abstract class WebTestUtils {
return (CsrfTokenRepository) ReflectionTestUtils.getField(filter, "tokenRepository"); return (CsrfTokenRepository) ReflectionTestUtils.getField(filter, "tokenRepository");
} }
/**
* Gets the {@link CsrfTokenRequestHandler} for the specified
* {@link HttpServletRequest}. If one is not found, the default
* {@link XorCsrfTokenRequestAttributeHandler} is used.
* @param request the {@link HttpServletRequest} to obtain the
* {@link CsrfTokenRequestHandler}
* @return the {@link CsrfTokenRequestHandler} for the specified
* {@link HttpServletRequest}
*/
public static CsrfTokenRequestHandler getCsrfTokenRequestHandler(HttpServletRequest request) {
CsrfFilter filter = findFilter(request, CsrfFilter.class);
if (filter == null) {
return DEFAULT_CSRF_HANDLER;
}
return (CsrfTokenRequestHandler) ReflectionTestUtils.getField(filter, "requestHandler");
}
/** /**
* Sets the {@link CsrfTokenRepository} for the specified {@link HttpServletRequest}. * Sets the {@link CsrfTokenRepository} for the specified {@link HttpServletRequest}.
* @param request the {@link HttpServletRequest} to obtain the * @param request the {@link HttpServletRequest} to obtain the

View File

@ -25,7 +25,6 @@ import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockServletContext; import org.springframework.mock.web.MockServletContext;
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.CsrfRequestPostProcessor;
import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult; import org.springframework.test.web.servlet.MvcResult;
@ -52,8 +51,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests {
@Test @Test
public void defaults() { public void defaults() {
MockHttpServletRequest request = formLogin().buildRequest(this.servletContext); MockHttpServletRequest request = formLogin().buildRequest(this.servletContext);
CsrfToken token = (CsrfToken) request CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
assertThat(request.getParameter("username")).isEqualTo("user"); assertThat(request.getParameter("username")).isEqualTo("user");
assertThat(request.getParameter("password")).isEqualTo("password"); assertThat(request.getParameter("password")).isEqualTo("password");
assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getMethod()).isEqualTo("POST");
@ -66,8 +64,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests {
public void custom() { public void custom() {
MockHttpServletRequest request = formLogin("/login").user("username", "admin").password("password", "secret") MockHttpServletRequest request = formLogin("/login").user("username", "admin").password("password", "secret")
.buildRequest(this.servletContext); .buildRequest(this.servletContext);
CsrfToken token = (CsrfToken) request CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
assertThat(request.getParameter("username")).isEqualTo("admin"); assertThat(request.getParameter("username")).isEqualTo("admin");
assertThat(request.getParameter("password")).isEqualTo("secret"); assertThat(request.getParameter("password")).isEqualTo("secret");
assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getMethod()).isEqualTo("POST");
@ -79,8 +76,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests {
public void customWithUriVars() { public void customWithUriVars() {
MockHttpServletRequest request = formLogin().loginProcessingUrl("/uri-login/{var1}/{var2}", "val1", "val2") MockHttpServletRequest request = formLogin().loginProcessingUrl("/uri-login/{var1}/{var2}", "val1", "val2")
.user("username", "admin").password("password", "secret").buildRequest(this.servletContext); .user("username", "admin").password("password", "secret").buildRequest(this.servletContext);
CsrfToken token = (CsrfToken) request CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
assertThat(request.getParameter("username")).isEqualTo("admin"); assertThat(request.getParameter("username")).isEqualTo("admin");
assertThat(request.getParameter("password")).isEqualTo("secret"); assertThat(request.getParameter("password")).isEqualTo("secret");
assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getMethod()).isEqualTo("POST");

View File

@ -25,7 +25,6 @@ import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockServletContext; import org.springframework.mock.web.MockServletContext;
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.CsrfRequestPostProcessor;
import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult; import org.springframework.test.web.servlet.MvcResult;
@ -52,8 +51,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests {
@Test @Test
public void defaults() { public void defaults() {
MockHttpServletRequest request = logout().buildRequest(this.servletContext); MockHttpServletRequest request = logout().buildRequest(this.servletContext);
CsrfToken token = (CsrfToken) request CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getMethod()).isEqualTo("POST");
assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken()); assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken());
assertThat(request.getRequestURI()).isEqualTo("/logout"); assertThat(request.getRequestURI()).isEqualTo("/logout");
@ -62,8 +60,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests {
@Test @Test
public void custom() { public void custom() {
MockHttpServletRequest request = logout("/admin/logout").buildRequest(this.servletContext); MockHttpServletRequest request = logout("/admin/logout").buildRequest(this.servletContext);
CsrfToken token = (CsrfToken) request CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getMethod()).isEqualTo("POST");
assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken()); assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken());
assertThat(request.getRequestURI()).isEqualTo("/admin/logout"); assertThat(request.getRequestURI()).isEqualTo("/admin/logout");
@ -73,8 +70,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests {
public void customWithUriVars() { public void customWithUriVars() {
MockHttpServletRequest request = logout().logoutUrl("/uri-logout/{var1}/{var2}", "val1", "val2") MockHttpServletRequest request = logout().logoutUrl("/uri-logout/{var1}/{var2}", "val1", "val2")
.buildRequest(this.servletContext); .buildRequest(this.servletContext);
CsrfToken token = (CsrfToken) request CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getMethod()).isEqualTo("POST");
assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken()); assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken());
assertThat(request.getRequestURI()).isEqualTo("/uri-logout/val1/val2"); assertThat(request.getRequestURI()).isEqualTo("/uri-logout/val1/val2");