Merge branch '5.8.x'

Closes gh-dry-run
This commit is contained in:
Steve Riesenberg 2022-10-04 11:18:00 -05:00
commit 5de6da890b
No known key found for this signature in database
GPG Key ID: 5F311AB48A55D521
29 changed files with 540 additions and 351 deletions

View File

@ -36,7 +36,6 @@ import org.springframework.security.web.csrf.CsrfAuthenticationStrategy;
import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfFilter;
import org.springframework.security.web.csrf.CsrfLogoutHandler; import org.springframework.security.web.csrf.CsrfLogoutHandler;
import org.springframework.security.web.csrf.CsrfTokenRepository; import org.springframework.security.web.csrf.CsrfTokenRepository;
import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler;
import org.springframework.security.web.csrf.CsrfTokenRequestHandler; import org.springframework.security.web.csrf.CsrfTokenRequestHandler;
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
import org.springframework.security.web.csrf.LazyCsrfTokenRepository; import org.springframework.security.web.csrf.LazyCsrfTokenRepository;
@ -249,13 +248,7 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>>
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Override @Override
public void configure(H http) { public void configure(H http) {
CsrfFilter filter; CsrfFilter filter = new CsrfFilter(this.csrfTokenRepository);
if (this.requestHandler != null) {
filter = new CsrfFilter(this.requestHandler);
}
else {
filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.csrfTokenRepository));
}
RequestMatcher requireCsrfProtectionMatcher = getRequireCsrfProtectionMatcher(); RequestMatcher requireCsrfProtectionMatcher = getRequireCsrfProtectionMatcher();
if (requireCsrfProtectionMatcher != null) { if (requireCsrfProtectionMatcher != null) {
filter.setRequireCsrfProtectionMatcher(requireCsrfProtectionMatcher); filter.setRequireCsrfProtectionMatcher(requireCsrfProtectionMatcher);
@ -272,6 +265,9 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>>
if (sessionConfigurer != null) { if (sessionConfigurer != null) {
sessionConfigurer.addSessionAuthenticationStrategy(getSessionAuthenticationStrategy()); sessionConfigurer.addSessionAuthenticationStrategy(getSessionAuthenticationStrategy());
} }
if (this.requestHandler != null) {
filter.setRequestHandler(this.requestHandler);
}
filter = postProcess(filter); filter = postProcess(filter);
http.addFilter(filter); http.addFilter(filter);
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -40,7 +40,6 @@ import org.springframework.security.web.access.DelegatingAccessDeniedHandler;
import org.springframework.security.web.csrf.CsrfAuthenticationStrategy; import org.springframework.security.web.csrf.CsrfAuthenticationStrategy;
import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfFilter;
import org.springframework.security.web.csrf.CsrfLogoutHandler; import org.springframework.security.web.csrf.CsrfLogoutHandler;
import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler;
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
import org.springframework.security.web.csrf.LazyCsrfTokenRepository; import org.springframework.security.web.csrf.LazyCsrfTokenRepository;
import org.springframework.security.web.csrf.MissingCsrfTokenException; import org.springframework.security.web.csrf.MissingCsrfTokenException;
@ -111,18 +110,13 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser {
new BeanComponentDefinition(lazyTokenRepository.getBeanDefinition(), this.csrfRepositoryRef)); new BeanComponentDefinition(lazyTokenRepository.getBeanDefinition(), this.csrfRepositoryRef));
} }
BeanDefinitionBuilder builder = BeanDefinitionBuilder.rootBeanDefinition(CsrfFilter.class); BeanDefinitionBuilder builder = BeanDefinitionBuilder.rootBeanDefinition(CsrfFilter.class);
if (!StringUtils.hasText(this.requestHandlerRef)) { builder.addConstructorArgReference(this.csrfRepositoryRef);
BeanDefinition csrfTokenRequestHandler = BeanDefinitionBuilder
.rootBeanDefinition(CsrfTokenRepositoryRequestHandler.class)
.addConstructorArgReference(this.csrfRepositoryRef).getBeanDefinition();
builder.addConstructorArgValue(csrfTokenRequestHandler);
}
else {
builder.addConstructorArgReference(this.requestHandlerRef);
}
if (StringUtils.hasText(this.requestMatcherRef)) { if (StringUtils.hasText(this.requestMatcherRef)) {
builder.addPropertyReference("requireCsrfProtectionMatcher", this.requestMatcherRef); builder.addPropertyReference("requireCsrfProtectionMatcher", this.requestMatcherRef);
} }
if (StringUtils.hasText(this.requestHandlerRef)) {
builder.addPropertyReference("requestHandler", this.requestHandlerRef);
}
this.csrfFilter = builder.getBeanDefinition(); this.csrfFilter = builder.getBeanDefinition();
return this.csrfFilter; return this.csrfFilter;
} }

View File

@ -1152,7 +1152,7 @@ csrf-options.attlist &=
## The CsrfTokenRepository to use. The default is HttpSessionCsrfTokenRepository wrapped by LazyCsrfTokenRepository. ## The CsrfTokenRepository to use. The default is HttpSessionCsrfTokenRepository wrapped by LazyCsrfTokenRepository.
attribute token-repository-ref { xsd:token }? attribute token-repository-ref { xsd:token }?
csrf-options.attlist &= csrf-options.attlist &=
## The CsrfTokenRequestHandler to use. The default is CsrfTokenRepositoryRequestHandler. ## The CsrfTokenRequestHandler to use. The default is CsrfTokenRequestAttributeHandler.
attribute request-handler-ref { xsd:token }? attribute request-handler-ref { xsd:token }?
headers = headers =

View File

@ -3258,7 +3258,7 @@
</xs:attribute> </xs:attribute>
<xs:attribute name="request-handler-ref" type="xs:token"> <xs:attribute name="request-handler-ref" type="xs:token">
<xs:annotation> <xs:annotation>
<xs:documentation>The CsrfTokenRequestHandler to use. The default is CsrfTokenRepositoryRequestHandler. <xs:documentation>The CsrfTokenRequestHandler to use. The default is CsrfTokenRequestAttributeHandler.
</xs:documentation> </xs:documentation>
</xs:annotation> </xs:annotation>
</xs:attribute> </xs:attribute>

View File

@ -43,8 +43,10 @@ import org.springframework.security.web.access.AccessDeniedHandler;
import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy; import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy;
import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.CsrfTokenRepository; import org.springframework.security.web.csrf.CsrfTokenRepository;
import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler; import org.springframework.security.web.csrf.CsrfTokenRequestAttributeHandler;
import org.springframework.security.web.csrf.CsrfTokenRequestHandler;
import org.springframework.security.web.csrf.DefaultCsrfToken; import org.springframework.security.web.csrf.DefaultCsrfToken;
import org.springframework.security.web.csrf.DeferredCsrfToken;
import org.springframework.security.web.firewall.StrictHttpFirewall; import org.springframework.security.web.firewall.StrictHttpFirewall;
import org.springframework.security.web.savedrequest.HttpSessionRequestCache; import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
import org.springframework.security.web.savedrequest.RequestCache; import org.springframework.security.web.savedrequest.RequestCache;
@ -62,11 +64,11 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.containsString;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isNull; import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.springframework.security.config.Customizer.withDefaults; import static org.springframework.security.config.Customizer.withDefaults;
@ -209,23 +211,23 @@ public class CsrfConfigurerTests {
public void loginWhenCsrfEnabledThenDoesNotRedirectToPreviousPostRequest() throws Exception { public void loginWhenCsrfEnabledThenDoesNotRedirectToPreviousPostRequest() throws Exception {
CsrfDisablesPostRequestFromRequestCacheConfig.REPO = mock(CsrfTokenRepository.class); CsrfDisablesPostRequestFromRequestCacheConfig.REPO = mock(CsrfTokenRepository.class);
DefaultCsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"); DefaultCsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token");
given(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.loadToken(any())).willReturn(csrfToken); given(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.loadDeferredToken(any(HttpServletRequest.class),
given(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.generateToken(any())).willReturn(csrfToken); any(HttpServletResponse.class))).willReturn(new TestDeferredCsrfToken(csrfToken));
this.spring.register(CsrfDisablesPostRequestFromRequestCacheConfig.class).autowire(); this.spring.register(CsrfDisablesPostRequestFromRequestCacheConfig.class).autowire();
MvcResult mvcResult = this.mvc.perform(post("/some-url")).andReturn(); MvcResult mvcResult = this.mvc.perform(post("/some-url")).andReturn();
this.mvc.perform(post("/login").param("username", "user").param("password", "password").with(csrf()) this.mvc.perform(post("/login").param("username", "user").param("password", "password").with(csrf())
.session((MockHttpSession) mvcResult.getRequest().getSession())).andExpect(status().isFound()) .session((MockHttpSession) mvcResult.getRequest().getSession())).andExpect(status().isFound())
.andExpect(redirectedUrl("/")); .andExpect(redirectedUrl("/"));
verify(CsrfDisablesPostRequestFromRequestCacheConfig.REPO, atLeastOnce()) verify(CsrfDisablesPostRequestFromRequestCacheConfig.REPO, atLeastOnce())
.loadToken(any(HttpServletRequest.class)); .loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class));
} }
@Test @Test
public void loginWhenCsrfEnabledThenRedirectsToPreviousGetRequest() throws Exception { public void loginWhenCsrfEnabledThenRedirectsToPreviousGetRequest() throws Exception {
CsrfDisablesPostRequestFromRequestCacheConfig.REPO = mock(CsrfTokenRepository.class); CsrfDisablesPostRequestFromRequestCacheConfig.REPO = mock(CsrfTokenRepository.class);
DefaultCsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"); DefaultCsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token");
given(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.loadToken(any())).willReturn(csrfToken); given(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.loadDeferredToken(any(HttpServletRequest.class),
given(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.generateToken(any())).willReturn(csrfToken); any(HttpServletResponse.class))).willReturn(new TestDeferredCsrfToken(csrfToken));
this.spring.register(CsrfDisablesPostRequestFromRequestCacheConfig.class).autowire(); this.spring.register(CsrfDisablesPostRequestFromRequestCacheConfig.class).autowire();
MvcResult mvcResult = this.mvc.perform(get("/some-url")).andReturn(); MvcResult mvcResult = this.mvc.perform(get("/some-url")).andReturn();
RequestCache requestCache = new HttpSessionRequestCache(); RequestCache requestCache = new HttpSessionRequestCache();
@ -233,6 +235,8 @@ public class CsrfConfigurerTests {
this.mvc.perform(post("/login").param("username", "user").param("password", "password").with(csrf()) this.mvc.perform(post("/login").param("username", "user").param("password", "password").with(csrf())
.session((MockHttpSession) mvcResult.getRequest().getSession())).andExpect(status().isFound()) .session((MockHttpSession) mvcResult.getRequest().getSession())).andExpect(status().isFound())
.andExpect(redirectedUrl(redirectUrl)); .andExpect(redirectedUrl(redirectUrl));
verify(CsrfDisablesPostRequestFromRequestCacheConfig.REPO, atLeastOnce())
.loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class));
} }
// SEC-2422 // SEC-2422
@ -279,11 +283,13 @@ public class CsrfConfigurerTests {
@Test @Test
public void postWhenCustomCsrfTokenRepositoryThenRepositoryIsUsed() throws Exception { public void postWhenCustomCsrfTokenRepositoryThenRepositoryIsUsed() throws Exception {
CsrfTokenRepositoryConfig.REPO = mock(CsrfTokenRepository.class); CsrfTokenRepositoryConfig.REPO = mock(CsrfTokenRepository.class);
given(CsrfTokenRepositoryConfig.REPO.loadToken(any())) given(CsrfTokenRepositoryConfig.REPO.loadDeferredToken(any(HttpServletRequest.class),
.willReturn(new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token")); any(HttpServletResponse.class)))
.willReturn(new TestDeferredCsrfToken(new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token")));
this.spring.register(CsrfTokenRepositoryConfig.class, BasicController.class).autowire(); this.spring.register(CsrfTokenRepositoryConfig.class, BasicController.class).autowire();
this.mvc.perform(post("/")); this.mvc.perform(post("/"));
verify(CsrfTokenRepositoryConfig.REPO).loadToken(any(HttpServletRequest.class)); verify(CsrfTokenRepositoryConfig.REPO).loadDeferredToken(any(HttpServletRequest.class),
any(HttpServletResponse.class));
} }
@Test @Test
@ -299,8 +305,8 @@ public class CsrfConfigurerTests {
public void loginWhenCustomCsrfTokenRepositoryThenCsrfTokenIsCleared() throws Exception { public void loginWhenCustomCsrfTokenRepositoryThenCsrfTokenIsCleared() throws Exception {
CsrfTokenRepositoryConfig.REPO = mock(CsrfTokenRepository.class); CsrfTokenRepositoryConfig.REPO = mock(CsrfTokenRepository.class);
DefaultCsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"); DefaultCsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token");
given(CsrfTokenRepositoryConfig.REPO.loadToken(any())).willReturn(csrfToken); given(CsrfTokenRepositoryConfig.REPO.loadDeferredToken(any(HttpServletRequest.class),
given(CsrfTokenRepositoryConfig.REPO.generateToken(any())).willReturn(csrfToken); any(HttpServletResponse.class))).willReturn(new TestDeferredCsrfToken(csrfToken));
this.spring.register(CsrfTokenRepositoryConfig.class, BasicController.class).autowire(); this.spring.register(CsrfTokenRepositoryConfig.class, BasicController.class).autowire();
// @formatter:off // @formatter:off
MockHttpServletRequestBuilder loginRequest = post("/login") MockHttpServletRequestBuilder loginRequest = post("/login")
@ -316,11 +322,13 @@ public class CsrfConfigurerTests {
@Test @Test
public void getWhenCustomCsrfTokenRepositoryInLambdaThenRepositoryIsUsed() throws Exception { public void getWhenCustomCsrfTokenRepositoryInLambdaThenRepositoryIsUsed() throws Exception {
CsrfTokenRepositoryInLambdaConfig.REPO = mock(CsrfTokenRepository.class); CsrfTokenRepositoryInLambdaConfig.REPO = mock(CsrfTokenRepository.class);
given(CsrfTokenRepositoryInLambdaConfig.REPO.loadToken(any())) given(CsrfTokenRepositoryInLambdaConfig.REPO.loadDeferredToken(any(HttpServletRequest.class),
.willReturn(new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token")); any(HttpServletResponse.class)))
.willReturn(new TestDeferredCsrfToken(new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token")));
this.spring.register(CsrfTokenRepositoryInLambdaConfig.class, BasicController.class).autowire(); this.spring.register(CsrfTokenRepositoryInLambdaConfig.class, BasicController.class).autowire();
this.mvc.perform(post("/")); this.mvc.perform(post("/"));
verify(CsrfTokenRepositoryInLambdaConfig.REPO).loadToken(any(HttpServletRequest.class)); verify(CsrfTokenRepositoryInLambdaConfig.REPO).loadDeferredToken(any(HttpServletRequest.class),
any(HttpServletResponse.class));
} }
@Test @Test
@ -420,30 +428,30 @@ public class CsrfConfigurerTests {
} }
@Test @Test
public void getLoginWhenCsrfTokenRequestProcessorSetThenRespondsWithNormalCsrfToken() throws Exception { public void getLoginWhenCsrfTokenRequestHandlerSetThenRespondsWithNormalCsrfToken() throws Exception {
CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class); CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class);
CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"); CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token");
given(csrfTokenRepository.generateToken(any(HttpServletRequest.class))).willReturn(csrfToken); given(csrfTokenRepository.loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class)))
CsrfTokenRequestProcessorConfig.HANDLER = new CsrfTokenRepositoryRequestHandler(csrfTokenRepository); .willReturn(new TestDeferredCsrfToken(csrfToken));
this.spring.register(CsrfTokenRequestProcessorConfig.class, BasicController.class).autowire(); CsrfTokenRequestHandlerConfig.REPO = csrfTokenRepository;
CsrfTokenRequestHandlerConfig.HANDLER = new CsrfTokenRequestAttributeHandler();
this.spring.register(CsrfTokenRequestHandlerConfig.class, BasicController.class).autowire();
this.mvc.perform(get("/login")).andExpect(status().isOk()) this.mvc.perform(get("/login")).andExpect(status().isOk())
.andExpect(content().string(containsString(csrfToken.getToken()))); .andExpect(content().string(containsString(csrfToken.getToken())));
verify(csrfTokenRepository).loadToken(any(HttpServletRequest.class)); verify(csrfTokenRepository).loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class));
verify(csrfTokenRepository).generateToken(any(HttpServletRequest.class));
verify(csrfTokenRepository).saveToken(eq(csrfToken), any(HttpServletRequest.class),
any(HttpServletResponse.class));
verifyNoMoreInteractions(csrfTokenRepository); verifyNoMoreInteractions(csrfTokenRepository);
} }
@Test @Test
public void loginWhenCsrfTokenRequestProcessorSetAndNormalCsrfTokenThenSuccess() throws Exception { public void loginWhenCsrfTokenRequestHandlerSetAndNormalCsrfTokenThenSuccess() throws Exception {
CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"); CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token");
CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class); CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class);
given(csrfTokenRepository.loadToken(any(HttpServletRequest.class))).willReturn(null, csrfToken); given(csrfTokenRepository.loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class)))
given(csrfTokenRepository.generateToken(any(HttpServletRequest.class))).willReturn(csrfToken); .willReturn(new TestDeferredCsrfToken(csrfToken));
CsrfTokenRequestProcessorConfig.HANDLER = new CsrfTokenRepositoryRequestHandler(csrfTokenRepository); CsrfTokenRequestHandlerConfig.REPO = csrfTokenRepository;
CsrfTokenRequestHandlerConfig.HANDLER = new CsrfTokenRequestAttributeHandler();
this.spring.register(CsrfTokenRequestHandlerConfig.class, BasicController.class).autowire();
this.spring.register(CsrfTokenRequestProcessorConfig.class, BasicController.class).autowire();
// @formatter:off // @formatter:off
MockHttpServletRequestBuilder loginRequest = post("/login") MockHttpServletRequestBuilder loginRequest = post("/login")
.header(csrfToken.getHeaderName(), csrfToken.getToken()) .header(csrfToken.getHeaderName(), csrfToken.getToken())
@ -451,9 +459,8 @@ public class CsrfConfigurerTests {
.param("password", "password"); .param("password", "password");
// @formatter:on // @formatter:on
this.mvc.perform(loginRequest).andExpect(redirectedUrl("/")); this.mvc.perform(loginRequest).andExpect(redirectedUrl("/"));
verify(csrfTokenRepository).loadToken(any(HttpServletRequest.class)); verify(csrfTokenRepository).saveToken(isNull(), any(HttpServletRequest.class), any(HttpServletResponse.class));
verify(csrfTokenRepository).generateToken(any(HttpServletRequest.class)); verify(csrfTokenRepository, times(2)).loadDeferredToken(any(HttpServletRequest.class),
verify(csrfTokenRepository).saveToken(eq(csrfToken), any(HttpServletRequest.class),
any(HttpServletResponse.class)); any(HttpServletResponse.class));
verifyNoMoreInteractions(csrfTokenRepository); verifyNoMoreInteractions(csrfTokenRepository);
} }
@ -819,9 +826,11 @@ public class CsrfConfigurerTests {
@Configuration @Configuration
@EnableWebSecurity @EnableWebSecurity
static class CsrfTokenRequestProcessorConfig { static class CsrfTokenRequestHandlerConfig {
static CsrfTokenRepositoryRequestHandler HANDLER; static CsrfTokenRepository REPO;
static CsrfTokenRequestHandler HANDLER;
@Bean @Bean
SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception { SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
@ -831,7 +840,10 @@ public class CsrfConfigurerTests {
.anyRequest().authenticated() .anyRequest().authenticated()
) )
.formLogin(Customizer.withDefaults()) .formLogin(Customizer.withDefaults())
.csrf((csrf) -> csrf.csrfTokenRequestHandler(HANDLER)); .csrf((csrf) -> csrf
.csrfTokenRepository(REPO)
.csrfTokenRequestHandler(HANDLER)
);
// @formatter:on // @formatter:on
return http.build(); return http.build();
@ -861,4 +873,24 @@ public class CsrfConfigurerTests {
} }
private static final class TestDeferredCsrfToken implements DeferredCsrfToken {
private final CsrfToken csrfToken;
private TestDeferredCsrfToken(CsrfToken csrfToken) {
this.csrfToken = csrfToken;
}
@Override
public CsrfToken get() {
return this.csrfToken;
}
@Override
public boolean isGenerated() {
return false;
}
}
} }

View File

@ -29,7 +29,6 @@ import org.junit.jupiter.api.extension.ExtendWith;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpMethod; import org.springframework.http.HttpMethod;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.mock.web.MockHttpSession; import org.springframework.mock.web.MockHttpSession;
import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.config.test.SpringTestContext; import org.springframework.security.config.test.SpringTestContext;
@ -42,7 +41,6 @@ import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.access.AccessDeniedHandler; import org.springframework.security.web.access.AccessDeniedHandler;
import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfFilter;
import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.DeferredCsrfToken;
import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.stereotype.Controller; import org.springframework.stereotype.Controller;
import org.springframework.test.context.junit.jupiter.SpringExtension; import org.springframework.test.context.junit.jupiter.SpringExtension;
@ -546,9 +544,8 @@ public class CsrfConfigTests {
@Override @Override
public void match(MvcResult result) { public void match(MvcResult result) {
MockHttpServletRequest request = result.getRequest(); MockHttpServletRequest request = result.getRequest();
MockHttpServletResponse response = result.getResponse(); CsrfToken token = WebTestUtils.getCsrfTokenRepository(request).loadToken(request);
DeferredCsrfToken token = WebTestUtils.getCsrfTokenRequestHandler(request).handle(request, response); assertThat(token).isNotNull();
assertThat(token.isGenerated()).isFalse();
} }
} }
@ -564,8 +561,7 @@ public class CsrfConfigTests {
@Override @Override
public void match(MvcResult result) throws Exception { public void match(MvcResult result) throws Exception {
MockHttpServletRequest request = result.getRequest(); MockHttpServletRequest request = result.getRequest();
MockHttpServletResponse response = result.getResponse(); CsrfToken token = WebTestUtils.getCsrfTokenRepository(request).loadToken(request);
CsrfToken token = WebTestUtils.getCsrfTokenRequestHandler(request).handle(request, response).get();
assertThat(token).isNotNull(); assertThat(token).isNotNull();
assertThat(token.getToken()).isEqualTo(this.token.apply(result)); assertThat(token.getToken()).isEqualTo(this.token.apply(result));
} }

View File

@ -26,7 +26,7 @@
<csrf request-handler-ref="requestHandler"/> <csrf request-handler-ref="requestHandler"/>
</http> </http>
<b:bean id="requestHandler" class="org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler" <b:bean id="requestHandler" class="org.springframework.security.web.csrf.CsrfTokenRequestAttributeHandler"
p:csrfRequestAttributeName="csrf-attribute-name"/> p:csrfRequestAttributeName="csrf-attribute-name"/>
<b:import resource="CsrfConfigTests-shared-userservice.xml"/> <b:import resource="CsrfConfigTests-shared-userservice.xml"/>
</b:beans> </b:beans>

View File

@ -782,7 +782,7 @@ The default is `HttpSessionCsrfTokenRepository`.
[[nsa-csrf-request-handler-ref]] [[nsa-csrf-request-handler-ref]]
* **request-handler-ref** * **request-handler-ref**
The optional `CsrfTokenRequestHandler` to use. The default is `CsrfTokenRepositoryRequestHandler`. The optional `CsrfTokenRequestHandler` to use. The default is `CsrfTokenRequestAttributeHandler`.
[[nsa-csrf-request-matcher-ref]] [[nsa-csrf-request-matcher-ref]]
* **request-matcher-ref** * **request-matcher-ref**

View File

@ -94,8 +94,7 @@ import org.springframework.security.web.context.SecurityContextPersistenceFilter
import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfFilter;
import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.CsrfTokenRequestHandler; import org.springframework.security.web.csrf.CsrfTokenRepository;
import org.springframework.security.web.csrf.DeferredCsrfToken;
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
import org.springframework.test.util.ReflectionTestUtils; import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MockMvc;
@ -509,13 +508,14 @@ public final class SecurityMockMvcRequestPostProcessors {
@Override @Override
public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) {
CsrfTokenRequestHandler handler = WebTestUtils.getCsrfTokenRequestHandler(request); CsrfTokenRepository repository = WebTestUtils.getCsrfTokenRepository(request);
if (!(handler instanceof TestCsrfTokenRequestHandler)) { if (!(repository instanceof TestCsrfTokenRepository)) {
handler = new TestCsrfTokenRequestHandler(handler); repository = new TestCsrfTokenRepository(new HttpSessionCsrfTokenRepository());
WebTestUtils.setCsrfTokenRequestHandler(request, handler); WebTestUtils.setCsrfTokenRepository(request, repository);
} }
TestCsrfTokenRequestHandler testHandler = (TestCsrfTokenRequestHandler) handler; TestCsrfTokenRepository.enable(request);
CsrfToken token = TestCsrfTokenRequestHandler.createTestCsrfToken(request); CsrfToken token = repository.generateToken(request);
repository.saveToken(token, request, new MockHttpServletResponse());
String tokenValue = this.useInvalidToken ? "invalid" + token.getToken() : token.getToken(); String tokenValue = this.useInvalidToken ? "invalid" + token.getToken() : token.getToken();
if (this.asHeader) { if (this.asHeader) {
request.addHeader(token.getHeaderName(), tokenValue); request.addHeader(token.getHeaderName(), tokenValue);
@ -549,56 +549,49 @@ public final class SecurityMockMvcRequestPostProcessors {
* Used to wrap the CsrfTokenRepository to provide support for testing when the * Used to wrap the CsrfTokenRepository to provide support for testing when the
* request is wrapped (i.e. Spring Session is in use). * request is wrapped (i.e. Spring Session is in use).
*/ */
static class TestCsrfTokenRequestHandler implements CsrfTokenRequestHandler { static class TestCsrfTokenRepository implements CsrfTokenRepository {
static final String TOKEN_ATTR_NAME = TestCsrfTokenRequestHandler.class.getName().concat(".TOKEN"); static final String TOKEN_ATTR_NAME = TestCsrfTokenRepository.class.getName().concat(".TOKEN");
static final String ENABLED_ATTR_NAME = TestCsrfTokenRequestHandler.class.getName().concat(".ENABLED"); static final String ENABLED_ATTR_NAME = TestCsrfTokenRepository.class.getName().concat(".ENABLED");
private final CsrfTokenRequestHandler delegate; private final CsrfTokenRepository delegate;
TestCsrfTokenRequestHandler(CsrfTokenRequestHandler delegate) { TestCsrfTokenRepository(CsrfTokenRepository delegate) {
this.delegate = delegate; this.delegate = delegate;
} }
static CsrfToken createTestCsrfToken(HttpServletRequest request) { @Override
CsrfToken existingToken = getExistingToken(request); public CsrfToken generateToken(HttpServletRequest request) {
if (existingToken != null) { return this.delegate.generateToken(request);
return existingToken;
}
HttpSessionCsrfTokenRepository repository = new HttpSessionCsrfTokenRepository();
CsrfToken csrfToken = repository.generateToken(request);
request.setAttribute(ENABLED_ATTR_NAME, true);
request.setAttribute(TOKEN_ATTR_NAME, csrfToken);
return csrfToken;
}
private static CsrfToken getExistingToken(HttpServletRequest request) {
Object existingToken = request.getAttribute(TOKEN_ATTR_NAME);
return (CsrfToken) existingToken;
}
boolean isEnabled(HttpServletRequest request) {
return getExistingToken(request) != null;
} }
@Override @Override
public DeferredCsrfToken handle(HttpServletRequest request, HttpServletResponse response) { public void saveToken(CsrfToken token, HttpServletRequest request, HttpServletResponse response) {
request.setAttribute(HttpServletResponse.class.getName(), response); if (isEnabled(request)) {
if (!isEnabled(request)) { request.setAttribute(TOKEN_ATTR_NAME, token);
return this.delegate.handle(request, response);
} }
return new DeferredCsrfToken() { else {
@Override this.delegate.saveToken(token, request, response);
public CsrfToken get() { }
return getExistingToken(request); }
}
@Override @Override
public boolean isGenerated() { public CsrfToken loadToken(HttpServletRequest request) {
return false; if (isEnabled(request)) {
} return (CsrfToken) request.getAttribute(TOKEN_ATTR_NAME);
}; }
else {
return this.delegate.loadToken(request);
}
}
static void enable(HttpServletRequest request) {
request.setAttribute(ENABLED_ATTR_NAME, Boolean.TRUE);
}
boolean isEnabled(HttpServletRequest request) {
return Boolean.TRUE.equals(request.getAttribute(ENABLED_ATTR_NAME));
} }
} }

View File

@ -31,8 +31,6 @@ import org.springframework.security.web.context.SecurityContextPersistenceFilter
import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfFilter;
import org.springframework.security.web.csrf.CsrfTokenRepository; import org.springframework.security.web.csrf.CsrfTokenRepository;
import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler;
import org.springframework.security.web.csrf.CsrfTokenRequestHandler;
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
import org.springframework.test.util.ReflectionTestUtils; import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.web.context.WebApplicationContext; import org.springframework.web.context.WebApplicationContext;
@ -48,7 +46,7 @@ public abstract class WebTestUtils {
private static final SecurityContextRepository DEFAULT_CONTEXT_REPO = new HttpSessionSecurityContextRepository(); private static final SecurityContextRepository DEFAULT_CONTEXT_REPO = new HttpSessionSecurityContextRepository();
private static final CsrfTokenRepositoryRequestHandler DEFAULT_CSRF_HANDLER = new CsrfTokenRepositoryRequestHandler(); private static final CsrfTokenRepository DEFAULT_TOKEN_REPO = new HttpSessionCsrfTokenRepository();
private WebTestUtils() { private WebTestUtils() {
} }
@ -101,24 +99,24 @@ public abstract class WebTestUtils {
* @return the {@link CsrfTokenRepository} for the specified * @return the {@link CsrfTokenRepository} for the specified
* {@link HttpServletRequest} * {@link HttpServletRequest}
*/ */
public static CsrfTokenRequestHandler getCsrfTokenRequestHandler(HttpServletRequest request) { public static CsrfTokenRepository getCsrfTokenRepository(HttpServletRequest request) {
CsrfFilter filter = findFilter(request, CsrfFilter.class); CsrfFilter filter = findFilter(request, CsrfFilter.class);
if (filter == null) { if (filter == null) {
return DEFAULT_CSRF_HANDLER; return DEFAULT_TOKEN_REPO;
} }
return (CsrfTokenRequestHandler) ReflectionTestUtils.getField(filter, "requestHandler"); return (CsrfTokenRepository) ReflectionTestUtils.getField(filter, "tokenRepository");
} }
/** /**
* Sets the {@link CsrfTokenRepository} for the specified {@link HttpServletRequest}. * Sets the {@link CsrfTokenRepository} for the specified {@link HttpServletRequest}.
* @param request the {@link HttpServletRequest} to obtain the * @param request the {@link HttpServletRequest} to obtain the
* {@link CsrfTokenRepository} * {@link CsrfTokenRepository}
* @param handler the {@link CsrfTokenRepository} to set * @param repository the {@link CsrfTokenRepository} to set
*/ */
public static void setCsrfTokenRequestHandler(HttpServletRequest request, CsrfTokenRequestHandler handler) { public static void setCsrfTokenRepository(HttpServletRequest request, CsrfTokenRepository repository) {
CsrfFilter filter = findFilter(request, CsrfFilter.class); CsrfFilter filter = findFilter(request, CsrfFilter.class);
if (filter != null) { if (filter != null) {
ReflectionTestUtils.setField(filter, "requestHandler", handler); ReflectionTestUtils.setField(filter, "tokenRepository", repository);
} }
} }

View File

@ -53,7 +53,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests {
public void defaults() { public void defaults() {
MockHttpServletRequest request = formLogin().buildRequest(this.servletContext); MockHttpServletRequest request = formLogin().buildRequest(this.servletContext);
CsrfToken token = (CsrfToken) request CsrfToken token = (CsrfToken) request
.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRequestHandler.TOKEN_ATTR_NAME); .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
assertThat(request.getParameter("username")).isEqualTo("user"); assertThat(request.getParameter("username")).isEqualTo("user");
assertThat(request.getParameter("password")).isEqualTo("password"); assertThat(request.getParameter("password")).isEqualTo("password");
assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getMethod()).isEqualTo("POST");
@ -67,7 +67,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests {
MockHttpServletRequest request = formLogin("/login").user("username", "admin").password("password", "secret") MockHttpServletRequest request = formLogin("/login").user("username", "admin").password("password", "secret")
.buildRequest(this.servletContext); .buildRequest(this.servletContext);
CsrfToken token = (CsrfToken) request CsrfToken token = (CsrfToken) request
.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRequestHandler.TOKEN_ATTR_NAME); .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
assertThat(request.getParameter("username")).isEqualTo("admin"); assertThat(request.getParameter("username")).isEqualTo("admin");
assertThat(request.getParameter("password")).isEqualTo("secret"); assertThat(request.getParameter("password")).isEqualTo("secret");
assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getMethod()).isEqualTo("POST");
@ -80,7 +80,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests {
MockHttpServletRequest request = formLogin().loginProcessingUrl("/uri-login/{var1}/{var2}", "val1", "val2") MockHttpServletRequest request = formLogin().loginProcessingUrl("/uri-login/{var1}/{var2}", "val1", "val2")
.user("username", "admin").password("password", "secret").buildRequest(this.servletContext); .user("username", "admin").password("password", "secret").buildRequest(this.servletContext);
CsrfToken token = (CsrfToken) request CsrfToken token = (CsrfToken) request
.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRequestHandler.TOKEN_ATTR_NAME); .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
assertThat(request.getParameter("username")).isEqualTo("admin"); assertThat(request.getParameter("username")).isEqualTo("admin");
assertThat(request.getParameter("password")).isEqualTo("secret"); assertThat(request.getParameter("password")).isEqualTo("secret");
assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getMethod()).isEqualTo("POST");

View File

@ -53,7 +53,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests {
public void defaults() { public void defaults() {
MockHttpServletRequest request = logout().buildRequest(this.servletContext); MockHttpServletRequest request = logout().buildRequest(this.servletContext);
CsrfToken token = (CsrfToken) request CsrfToken token = (CsrfToken) request
.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRequestHandler.TOKEN_ATTR_NAME); .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getMethod()).isEqualTo("POST");
assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken()); assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken());
assertThat(request.getRequestURI()).isEqualTo("/logout"); assertThat(request.getRequestURI()).isEqualTo("/logout");
@ -63,7 +63,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests {
public void custom() { public void custom() {
MockHttpServletRequest request = logout("/admin/logout").buildRequest(this.servletContext); MockHttpServletRequest request = logout("/admin/logout").buildRequest(this.servletContext);
CsrfToken token = (CsrfToken) request CsrfToken token = (CsrfToken) request
.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRequestHandler.TOKEN_ATTR_NAME); .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getMethod()).isEqualTo("POST");
assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken()); assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken());
assertThat(request.getRequestURI()).isEqualTo("/admin/logout"); assertThat(request.getRequestURI()).isEqualTo("/admin/logout");
@ -74,7 +74,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests {
MockHttpServletRequest request = logout().logoutUrl("/uri-logout/{var1}/{var2}", "val1", "val2") MockHttpServletRequest request = logout().logoutUrl("/uri-logout/{var1}/{var2}", "val1", "val2")
.buildRequest(this.servletContext); .buildRequest(this.servletContext);
CsrfToken token = (CsrfToken) request CsrfToken token = (CsrfToken) request
.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRequestHandler.TOKEN_ATTR_NAME); .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getMethod()).isEqualTo("POST");
assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken()); assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken());
assertThat(request.getRequestURI()).isEqualTo("/uri-logout/val1/val2"); assertThat(request.getRequestURI()).isEqualTo("/uri-logout/val1/val2");

View File

@ -0,0 +1,78 @@
/*
* Copyright 2002-2016 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.test.web.servlet.request;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.configuration.WebSecurityCustomizer;
import org.springframework.security.test.web.support.WebTestUtils;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.csrf.CookieCsrfTokenRepository;
import org.springframework.security.web.csrf.CsrfTokenRepository;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit.jupiter.SpringExtension;
import org.springframework.test.context.web.WebAppConfiguration;
import org.springframework.web.context.WebApplicationContext;
import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
@ExtendWith(SpringExtension.class)
@ContextConfiguration
@WebAppConfiguration
public class SecurityMockMvcRequestPostProcessorsCsrfDebugFilterTests {
@Autowired
private WebApplicationContext wac;
// SEC-3836
@Test
public void findCookieCsrfTokenRepository() {
MockHttpServletRequest request = post("/").buildRequest(this.wac.getServletContext());
CsrfTokenRepository csrfTokenRepository = WebTestUtils.getCsrfTokenRepository(request);
assertThat(csrfTokenRepository).isNotNull();
assertThat(csrfTokenRepository).isEqualTo(Config.cookieCsrfTokenRepository);
}
@Configuration
@EnableWebSecurity
static class Config {
static CsrfTokenRepository cookieCsrfTokenRepository = new CookieCsrfTokenRepository();
@Bean
SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
http.csrf().csrfTokenRepository(cookieCsrfTokenRepository);
return http.build();
}
@Bean
WebSecurityCustomizer webSecurityCustomizer() {
// Enable the DebugFilter
return (web) -> web.debug(true);
}
}
}

View File

@ -39,7 +39,6 @@ import org.springframework.security.web.context.SecurityContextPersistenceFilter
import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfFilter;
import org.springframework.security.web.csrf.CsrfTokenRepository; import org.springframework.security.web.csrf.CsrfTokenRepository;
import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler;
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
import org.springframework.security.web.util.matcher.AnyRequestMatcher; import org.springframework.security.web.util.matcher.AnyRequestMatcher;
import org.springframework.web.context.WebApplicationContext; import org.springframework.web.context.WebApplicationContext;
@ -75,22 +74,22 @@ public class WebTestUtilsTests {
@Test @Test
public void getCsrfTokenRepositorytNoWac() { public void getCsrfTokenRepositorytNoWac() {
assertThat(WebTestUtils.getCsrfTokenRequestHandler(this.request)) assertThat(WebTestUtils.getCsrfTokenRepository(this.request))
.isInstanceOf(CsrfTokenRepositoryRequestHandler.class); .isInstanceOf(HttpSessionCsrfTokenRepository.class);
} }
@Test @Test
public void getCsrfTokenRepositorytNoSecurity() { public void getCsrfTokenRepositorytNoSecurity() {
loadConfig(Config.class); loadConfig(Config.class);
assertThat(WebTestUtils.getCsrfTokenRequestHandler(this.request)) assertThat(WebTestUtils.getCsrfTokenRepository(this.request))
.isInstanceOf(CsrfTokenRepositoryRequestHandler.class); .isInstanceOf(HttpSessionCsrfTokenRepository.class);
} }
@Test @Test
public void getCsrfTokenRepositorytSecurityNoCsrf() { public void getCsrfTokenRepositorytSecurityNoCsrf() {
loadConfig(SecurityNoCsrfConfig.class); loadConfig(SecurityNoCsrfConfig.class);
assertThat(WebTestUtils.getCsrfTokenRequestHandler(this.request)) assertThat(WebTestUtils.getCsrfTokenRepository(this.request))
.isInstanceOf(CsrfTokenRepositoryRequestHandler.class); .isInstanceOf(HttpSessionCsrfTokenRepository.class);
} }
@Test @Test
@ -98,7 +97,7 @@ public class WebTestUtilsTests {
CustomSecurityConfig.CONTEXT_REPO = this.contextRepo; CustomSecurityConfig.CONTEXT_REPO = this.contextRepo;
CustomSecurityConfig.CSRF_REPO = this.csrfRepo; CustomSecurityConfig.CSRF_REPO = this.csrfRepo;
loadConfig(CustomSecurityConfig.class); loadConfig(CustomSecurityConfig.class);
// assertThat(WebTestUtils.getCsrfTokenRepository(this.request)).isSameAs(this.csrfRepo); assertThat(WebTestUtils.getCsrfTokenRepository(this.request)).isSameAs(this.csrfRepo);
} }
// getSecurityContextRepository // getSecurityContextRepository

View File

@ -38,17 +38,17 @@ public final class CsrfAuthenticationStrategy implements SessionAuthenticationSt
private final Log logger = LogFactory.getLog(getClass()); private final Log logger = LogFactory.getLog(getClass());
private final CsrfTokenRepository csrfTokenRepository; private final CsrfTokenRepository tokenRepository;
private CsrfTokenRequestHandler requestHandler; private CsrfTokenRequestHandler requestHandler = new CsrfTokenRequestAttributeHandler();
/** /**
* Creates a new instance * Creates a new instance
* @param csrfTokenRepository the {@link CsrfTokenRepository} to use * @param tokenRepository the {@link CsrfTokenRepository} to use
*/ */
public CsrfAuthenticationStrategy(CsrfTokenRepository csrfTokenRepository) { public CsrfAuthenticationStrategy(CsrfTokenRepository tokenRepository) {
this.requestHandler = new CsrfTokenRepositoryRequestHandler(csrfTokenRepository); Assert.notNull(tokenRepository, "tokenRepository cannot be null");
this.csrfTokenRepository = csrfTokenRepository; this.tokenRepository = tokenRepository;
} }
/** /**
@ -64,8 +64,9 @@ public final class CsrfAuthenticationStrategy implements SessionAuthenticationSt
@Override @Override
public void onAuthentication(Authentication authentication, HttpServletRequest request, public void onAuthentication(Authentication authentication, HttpServletRequest request,
HttpServletResponse response) throws SessionAuthenticationException { HttpServletResponse response) throws SessionAuthenticationException {
this.csrfTokenRepository.saveToken(null, request, response); this.tokenRepository.saveToken(null, request, response);
this.requestHandler.handle(request, response); DeferredCsrfToken deferredCsrfToken = this.tokenRepository.loadDeferredToken(request, response);
this.requestHandler.handle(request, response, deferredCsrfToken::get);
this.logger.debug("Replaced CSRF Token"); this.logger.debug("Replaced CSRF Token");
} }

View File

@ -81,30 +81,21 @@ public final class CsrfFilter extends OncePerRequestFilter {
private final Log logger = LogFactory.getLog(getClass()); private final Log logger = LogFactory.getLog(getClass());
private final CsrfTokenRequestHandler requestHandler; private final CsrfTokenRepository tokenRepository;
private RequestMatcher requireCsrfProtectionMatcher = DEFAULT_CSRF_MATCHER; private RequestMatcher requireCsrfProtectionMatcher = DEFAULT_CSRF_MATCHER;
private AccessDeniedHandler accessDeniedHandler = new AccessDeniedHandlerImpl(); private AccessDeniedHandler accessDeniedHandler = new AccessDeniedHandlerImpl();
/** private CsrfTokenRequestHandler requestHandler = new CsrfTokenRequestAttributeHandler();
* Creates a new instance.
* @param csrfTokenRepository the {@link CsrfTokenRepository} to use
* @deprecated Use {@link CsrfFilter#CsrfFilter(CsrfTokenRequestHandler)} instead
*/
@Deprecated
public CsrfFilter(CsrfTokenRepository csrfTokenRepository) {
this(new CsrfTokenRepositoryRequestHandler(csrfTokenRepository));
}
/** /**
* Creates a new instance. * Creates a new instance.
* @param requestHandler the {@link CsrfTokenRequestHandler} to use. Default is * @param tokenRepository the {@link CsrfTokenRepository} to use
* {@link CsrfTokenRepositoryRequestHandler}.
*/ */
public CsrfFilter(CsrfTokenRequestHandler requestHandler) { public CsrfFilter(CsrfTokenRepository tokenRepository) {
Assert.notNull(requestHandler, "requestHandler cannot be null"); Assert.notNull(tokenRepository, "tokenRepository cannot be null");
this.requestHandler = requestHandler; this.tokenRepository = tokenRepository;
} }
@Override @Override
@ -115,7 +106,8 @@ public final class CsrfFilter extends OncePerRequestFilter {
@Override @Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException { throws ServletException, IOException {
DeferredCsrfToken deferredCsrfToken = this.requestHandler.handle(request, response); DeferredCsrfToken deferredCsrfToken = this.tokenRepository.loadDeferredToken(request, response);
this.requestHandler.handle(request, response, deferredCsrfToken::get);
if (!this.requireCsrfProtectionMatcher.matches(request)) { if (!this.requireCsrfProtectionMatcher.matches(request)) {
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
this.logger.trace("Did not protect against CSRF since request did not match " this.logger.trace("Did not protect against CSRF since request did not match "
@ -173,6 +165,21 @@ public final class CsrfFilter extends OncePerRequestFilter {
this.accessDeniedHandler = accessDeniedHandler; this.accessDeniedHandler = accessDeniedHandler;
} }
/**
* Specifies a {@link CsrfTokenRequestHandler} that is used to make the
* {@link CsrfToken} available as a request attribute.
*
* <p>
* The default is {@link CsrfTokenRequestAttributeHandler}.
* </p>
* @param requestHandler the {@link CsrfTokenRequestHandler} to use
* @since 5.8
*/
public void setRequestHandler(CsrfTokenRequestHandler requestHandler) {
Assert.notNull(requestHandler, "requestHandler cannot be null");
this.requestHandler = requestHandler;
}
/** /**
* Constant time comparison to prevent against timing attacks. * Constant time comparison to prevent against timing attacks.
* @param expected * @param expected

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2013 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -26,6 +26,7 @@ import jakarta.servlet.http.HttpSession;
* {@link HttpSession}. * {@link HttpSession}.
* *
* @author Rob Winch * @author Rob Winch
* @author Steve Riesenberg
* @since 3.2 * @since 3.2
* @see HttpSessionCsrfTokenRepository * @see HttpSessionCsrfTokenRepository
*/ */
@ -55,4 +56,20 @@ public interface CsrfTokenRepository {
*/ */
CsrfToken loadToken(HttpServletRequest request); CsrfToken loadToken(HttpServletRequest request);
/**
* Defers loading the {@link CsrfToken} using the {@link HttpServletRequest} and
* {@link HttpServletResponse} until it is needed by the application.
* <p>
* The returned {@link DeferredCsrfToken} is cached to allow subsequent calls to
* {@link DeferredCsrfToken#get()} to return the same {@link CsrfToken} without the
* cost of loading or generating the token again.
* @param request the {@link HttpServletRequest} to use
* @param response the {@link HttpServletResponse} to use
* @return a {@link DeferredCsrfToken} that will load the {@link CsrfToken}
* @since 5.8
*/
default DeferredCsrfToken loadDeferredToken(HttpServletRequest request, HttpServletResponse response) {
return new RepositoryDeferredCsrfToken(this, request, response);
}
} }

View File

@ -31,29 +31,10 @@ import org.springframework.util.Assert;
* @author Steve Riesenberg * @author Steve Riesenberg
* @since 5.8 * @since 5.8
*/ */
public class CsrfTokenRepositoryRequestHandler implements CsrfTokenRequestHandler { public class CsrfTokenRequestAttributeHandler implements CsrfTokenRequestHandler {
private final CsrfTokenRepository csrfTokenRepository;
private String csrfRequestAttributeName = "_csrf"; private String csrfRequestAttributeName = "_csrf";
/**
* Creates a new instance.
*/
public CsrfTokenRepositoryRequestHandler() {
this(new HttpSessionCsrfTokenRepository());
}
/**
* Creates a new instance.
* @param csrfTokenRepository the {@link CsrfTokenRepository} to use. Default
* {@link HttpSessionCsrfTokenRepository}
*/
public CsrfTokenRepositoryRequestHandler(CsrfTokenRepository csrfTokenRepository) {
Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null");
this.csrfTokenRepository = csrfTokenRepository;
}
/** /**
* The {@link CsrfToken} is available as a request attribute named * The {@link CsrfToken} is available as a request attribute named
* {@code CsrfToken.class.getName()}. By default, an additional request attribute that * {@code CsrfToken.class.getName()}. By default, an additional request attribute that
@ -67,18 +48,18 @@ public class CsrfTokenRepositoryRequestHandler implements CsrfTokenRequestHandle
} }
@Override @Override
public DeferredCsrfToken handle(HttpServletRequest request, HttpServletResponse response) { public void handle(HttpServletRequest request, HttpServletResponse response,
Supplier<CsrfToken> deferredCsrfToken) {
Assert.notNull(request, "request cannot be null"); Assert.notNull(request, "request cannot be null");
Assert.notNull(response, "response cannot be null"); Assert.notNull(response, "response cannot be null");
Assert.notNull(deferredCsrfToken, "deferredCsrfToken cannot be null");
request.setAttribute(HttpServletResponse.class.getName(), response); request.setAttribute(HttpServletResponse.class.getName(), response);
DeferredCsrfToken deferredCsrfToken = new RepositoryDeferredCsrfToken(request, response); CsrfToken csrfToken = new SupplierCsrfToken(deferredCsrfToken);
CsrfToken csrfToken = new SupplierCsrfToken(deferredCsrfToken::get);
request.setAttribute(CsrfToken.class.getName(), csrfToken); request.setAttribute(CsrfToken.class.getName(), csrfToken);
String csrfAttrName = (this.csrfRequestAttributeName != null) ? this.csrfRequestAttributeName String csrfAttrName = (this.csrfRequestAttributeName != null) ? this.csrfRequestAttributeName
: csrfToken.getParameterName(); : csrfToken.getParameterName();
request.setAttribute(csrfAttrName, csrfToken); request.setAttribute(csrfAttrName, csrfToken);
return deferredCsrfToken;
} }
private static final class SupplierCsrfToken implements CsrfToken { private static final class SupplierCsrfToken implements CsrfToken {
@ -114,46 +95,4 @@ public class CsrfTokenRepositoryRequestHandler implements CsrfTokenRequestHandle
} }
private final class RepositoryDeferredCsrfToken implements DeferredCsrfToken {
private final HttpServletRequest request;
private final HttpServletResponse response;
private CsrfToken csrfToken;
private Boolean missingToken;
RepositoryDeferredCsrfToken(HttpServletRequest request, HttpServletResponse response) {
this.request = request;
this.response = response;
}
@Override
public CsrfToken get() {
init();
return this.csrfToken;
}
@Override
public boolean isGenerated() {
init();
return this.missingToken;
}
private void init() {
if (this.csrfToken != null) {
return;
}
this.csrfToken = CsrfTokenRepositoryRequestHandler.this.csrfTokenRepository.loadToken(this.request);
this.missingToken = (this.csrfToken == null);
if (this.missingToken) {
this.csrfToken = CsrfTokenRepositoryRequestHandler.this.csrfTokenRepository.generateToken(this.request);
CsrfTokenRepositoryRequestHandler.this.csrfTokenRepository.saveToken(this.csrfToken, this.request,
this.response);
}
}
}
} }

View File

@ -16,20 +16,22 @@
package org.springframework.security.web.csrf; package org.springframework.security.web.csrf;
import java.util.function.Supplier;
import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse; import jakarta.servlet.http.HttpServletResponse;
import org.springframework.util.Assert; import org.springframework.util.Assert;
/** /**
* An interface that is used to determine the {@link CsrfToken} to use and make the * A callback interface that is used to make the {@link CsrfToken} created by the
* {@link CsrfToken} available as a request attribute. Implementations of this interface * {@link CsrfTokenRepository} available as a request attribute. Implementations of this
* may choose to perform additional tasks or customize how the token is made available to * interface may choose to perform additional tasks or customize how the token is made
* the application through request attributes. * available to the application through request attributes.
* *
* @author Steve Riesenberg * @author Steve Riesenberg
* @since 5.8 * @since 5.8
* @see CsrfTokenRepositoryRequestHandler * @see CsrfTokenRequestAttributeHandler
*/ */
@FunctionalInterface @FunctionalInterface
public interface CsrfTokenRequestHandler extends CsrfTokenRequestResolver { public interface CsrfTokenRequestHandler extends CsrfTokenRequestResolver {
@ -38,8 +40,9 @@ public interface CsrfTokenRequestHandler extends CsrfTokenRequestResolver {
* Handles a request using a {@link CsrfToken}. * Handles a request using a {@link CsrfToken}.
* @param request the {@code HttpServletRequest} being handled * @param request the {@code HttpServletRequest} being handled
* @param response the {@code HttpServletResponse} being handled * @param response the {@code HttpServletResponse} being handled
* @param csrfToken the {@link CsrfToken} created by the {@link CsrfTokenRepository}
*/ */
DeferredCsrfToken handle(HttpServletRequest request, HttpServletResponse response); void handle(HttpServletRequest request, HttpServletResponse response, Supplier<CsrfToken> csrfToken);
@Override @Override
default String resolveCsrfTokenValue(HttpServletRequest request, CsrfToken csrfToken) { default String resolveCsrfTokenValue(HttpServletRequest request, CsrfToken csrfToken) {

View File

@ -25,7 +25,7 @@ import jakarta.servlet.http.HttpServletRequest;
* *
* @author Steve Riesenberg * @author Steve Riesenberg
* @since 5.8 * @since 5.8
* @see CsrfTokenRepositoryRequestHandler * @see CsrfTokenRequestAttributeHandler
*/ */
@FunctionalInterface @FunctionalInterface
public interface CsrfTokenRequestResolver { public interface CsrfTokenRequestResolver {

View File

@ -20,11 +20,12 @@ package org.springframework.security.web.csrf;
* An interface that allows delayed access to a {@link CsrfToken} that may be generated. * An interface that allows delayed access to a {@link CsrfToken} that may be generated.
* *
* @author Rob Winch * @author Rob Winch
* @author Steve Riesenberg
* @since 5.8 * @since 5.8
*/ */
public interface DeferredCsrfToken { public interface DeferredCsrfToken {
/*** /**
* Gets the {@link CsrfToken} * Gets the {@link CsrfToken}
* @return a non-null {@link CsrfToken} * @return a non-null {@link CsrfToken}
*/ */

View File

@ -27,8 +27,9 @@ import org.springframework.util.Assert;
* *
* @author Rob Winch * @author Rob Winch
* @since 4.1 * @since 4.1
* @deprecated Use org.springframework.security.web.csrf.CsrfTokenRequestHandler which * @deprecated Use
* returns a {@link DeferredCsrfToken} * {@link CsrfTokenRepository#loadDeferredToken(HttpServletRequest, HttpServletResponse)}
* which returns a {@link DeferredCsrfToken}
*/ */
@Deprecated @Deprecated
public final class LazyCsrfTokenRepository implements CsrfTokenRepository { public final class LazyCsrfTokenRepository implements CsrfTokenRepository {

View File

@ -0,0 +1,71 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.web.csrf;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
/**
* @author Rob Winch
* @author Steve Riesenberg
* @since 5.8
*/
final class RepositoryDeferredCsrfToken implements DeferredCsrfToken {
private final CsrfTokenRepository csrfTokenRepository;
private final HttpServletRequest request;
private final HttpServletResponse response;
private CsrfToken csrfToken;
private boolean missingToken;
RepositoryDeferredCsrfToken(CsrfTokenRepository csrfTokenRepository, HttpServletRequest request,
HttpServletResponse response) {
this.csrfTokenRepository = csrfTokenRepository;
this.request = request;
this.response = response;
}
@Override
public CsrfToken get() {
init();
return this.csrfToken;
}
@Override
public boolean isGenerated() {
init();
return this.missingToken;
}
private void init() {
if (this.csrfToken != null) {
return;
}
this.csrfToken = this.csrfTokenRepository.loadToken(this.request);
this.missingToken = (this.csrfToken == null);
if (this.missingToken) {
this.csrfToken = this.csrfTokenRepository.generateToken(this.request);
this.csrfTokenRepository.saveToken(this.csrfToken, this.request, this.response);
}
}
}

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2017 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -25,6 +25,7 @@ import org.springframework.mock.web.MockHttpServletResponse;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;
/** /**
* @author Rob Winch * @author Rob Winch
@ -245,6 +246,33 @@ public class CookieCsrfTokenRepositoryTests {
assertThat(loadToken.getToken()).isEqualTo(value); assertThat(loadToken.getToken()).isEqualTo(value);
} }
@Test
public void loadDeferredTokenWhenDoesNotExistThenGeneratedAndSaved() {
DeferredCsrfToken deferredCsrfToken = this.repository.loadDeferredToken(this.request, this.response);
CsrfToken csrfToken = deferredCsrfToken.get();
assertThat(csrfToken).isNotNull();
assertThat(deferredCsrfToken.isGenerated()).isTrue();
Cookie tokenCookie = this.response.getCookie(CookieCsrfTokenRepository.DEFAULT_CSRF_COOKIE_NAME);
assertThat(tokenCookie).isNotNull();
assertThat(tokenCookie.getMaxAge()).isEqualTo(-1);
assertThat(tokenCookie.getName()).isEqualTo(CookieCsrfTokenRepository.DEFAULT_CSRF_COOKIE_NAME);
assertThat(tokenCookie.getPath()).isEqualTo(this.request.getContextPath());
assertThat(tokenCookie.getSecure()).isEqualTo(this.request.isSecure());
assertThat(tokenCookie.getValue()).isEqualTo(csrfToken.getToken());
assertThat(tokenCookie.isHttpOnly()).isEqualTo(true);
}
@Test
public void loadDeferredTokenWhenExistsThenLoaded() {
CsrfToken generatedToken = this.repository.generateToken(this.request);
this.request
.setCookies(new Cookie(CookieCsrfTokenRepository.DEFAULT_CSRF_COOKIE_NAME, generatedToken.getToken()));
DeferredCsrfToken deferredCsrfToken = this.repository.loadDeferredToken(this.request, this.response);
CsrfToken csrfToken = deferredCsrfToken.get();
assertThatCsrfToken(csrfToken).isEqualTo(generatedToken);
assertThat(deferredCsrfToken.isGenerated()).isFalse();
}
@Test @Test
public void setCookieNameNullIllegalArgumentException() { public void setCookieNameNullIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.repository.setCookieName(null)); assertThatIllegalArgumentException().isThrownBy(() -> this.repository.setCookieName(null));

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2013 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -32,7 +32,6 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
@ -82,31 +81,31 @@ public class CsrfAuthenticationStrategyTests {
@Test @Test
public void onAuthenticationWhenCustomRequestHandlerThenUsed() { public void onAuthenticationWhenCustomRequestHandlerThenUsed() {
given(this.csrfTokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.existingToken, false));
CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class); CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class);
this.strategy.setRequestHandler(requestHandler); this.strategy.setRequestHandler(requestHandler);
this.strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"), this.request, this.strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"), this.request,
this.response); this.response);
verify(requestHandler).handle(eq(this.request), eq(this.response)); verify(requestHandler).handle(eq(this.request), eq(this.response), any());
verifyNoMoreInteractions(requestHandler); verifyNoMoreInteractions(requestHandler);
} }
@Test @Test
public void logoutRemovesCsrfTokenAndSavesNew() { public void logoutRemovesCsrfTokenAndLoadsNewDeferredCsrfToken() {
given(this.csrfTokenRepository.loadToken(this.request)).willReturn(null, this.existingToken); given(this.csrfTokenRepository.loadDeferredToken(this.request, this.response))
given(this.csrfTokenRepository.generateToken(this.request)).willReturn(this.generatedToken); .willReturn(new TestDeferredCsrfToken(this.generatedToken, false));
this.strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"), this.request, this.strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"), this.request,
this.response); this.response);
verify(this.csrfTokenRepository).saveToken(null, this.request, this.response);
verify(this.csrfTokenRepository).loadDeferredToken(this.request, this.response);
// SEC-2404, SEC-2832 // SEC-2404, SEC-2832
CsrfToken tokenInRequest = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName()); CsrfToken tokenInRequest = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName());
assertThat(tokenInRequest.getToken()).isSameAs(this.generatedToken.getToken()); assertThat(tokenInRequest.getToken()).isSameAs(this.generatedToken.getToken());
assertThat(tokenInRequest.getHeaderName()).isSameAs(this.generatedToken.getHeaderName()); assertThat(tokenInRequest.getHeaderName()).isSameAs(this.generatedToken.getHeaderName());
assertThat(tokenInRequest.getParameterName()).isSameAs(this.generatedToken.getParameterName()); assertThat(tokenInRequest.getParameterName()).isSameAs(this.generatedToken.getParameterName());
assertThat(this.request.getAttribute(this.generatedToken.getParameterName())).isSameAs(tokenInRequest); assertThat(this.request.getAttribute(this.generatedToken.getParameterName())).isSameAs(tokenInRequest);
// verify after the test accesses the CsrfToken which causes the lazy save to
// occur
verify(this.csrfTokenRepository).saveToken(null, this.request, this.response);
verify(this.csrfTokenRepository).saveToken(eq(this.generatedToken), any(HttpServletRequest.class),
any(HttpServletResponse.class));
} }
// SEC-2872 // SEC-2872
@ -121,16 +120,10 @@ public class CsrfAuthenticationStrategyTests {
any(HttpServletResponse.class)); any(HttpServletResponse.class));
CsrfToken tokenInRequest = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName()); CsrfToken tokenInRequest = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName());
tokenInRequest.getToken(); tokenInRequest.getToken();
verify(this.csrfTokenRepository).loadToken(this.request);
verify(this.csrfTokenRepository).generateToken(this.request);
verify(this.csrfTokenRepository).saveToken(eq(this.generatedToken), any(HttpServletRequest.class), verify(this.csrfTokenRepository).saveToken(eq(this.generatedToken), any(HttpServletRequest.class),
any(HttpServletResponse.class)); any(HttpServletResponse.class));
} }
@Test
public void logoutWhenNoCsrfToken() {
this.strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"), this.request,
this.response);
verify(this.csrfTokenRepository).saveToken(isNull(), any(HttpServletRequest.class),
any(HttpServletResponse.class));
}
} }

View File

@ -43,7 +43,6 @@ import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoInteractions;
@ -87,11 +86,7 @@ public class CsrfFilterTests {
} }
private CsrfFilter createCsrfFilter(CsrfTokenRepository repository) { private CsrfFilter createCsrfFilter(CsrfTokenRepository repository) {
return createCsrfFilter(new CsrfTokenRepositoryRequestHandler(repository)); CsrfFilter filter = new CsrfFilter(repository);
}
private CsrfFilter createCsrfFilter(CsrfTokenRequestHandler requestHandler) {
CsrfFilter filter = new CsrfFilter(requestHandler);
filter.setRequireCsrfProtectionMatcher(this.requestMatcher); filter.setRequireCsrfProtectionMatcher(this.requestMatcher);
filter.setAccessDeniedHandler(this.deniedHandler); filter.setAccessDeniedHandler(this.deniedHandler);
return filter; return filter;
@ -104,7 +99,7 @@ public class CsrfFilterTests {
@Test @Test
public void constructorNullRepository() { public void constructorNullRepository() {
assertThatIllegalArgumentException().isThrownBy(() -> new CsrfFilter((CsrfTokenRequestHandler) null)); assertThatIllegalArgumentException().isThrownBy(() -> new CsrfFilter(null));
} }
// SEC-2276 // SEC-2276
@ -129,7 +124,8 @@ public class CsrfFilterTests {
@Test @Test
public void doFilterAccessDeniedNoTokenPresent() throws ServletException, IOException { public void doFilterAccessDeniedNoTokenPresent() throws ServletException, IOException {
given(this.requestMatcher.matches(this.request)).willReturn(true); given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
this.filter.doFilter(this.request, this.response, this.filterChain); this.filter.doFilter(this.request, this.response, this.filterChain);
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
@ -140,7 +136,8 @@ public class CsrfFilterTests {
@Test @Test
public void doFilterAccessDeniedIncorrectTokenPresent() throws ServletException, IOException { public void doFilterAccessDeniedIncorrectTokenPresent() throws ServletException, IOException {
given(this.requestMatcher.matches(this.request)).willReturn(true); given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID"); this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID");
this.filter.doFilter(this.request, this.response, this.filterChain); this.filter.doFilter(this.request, this.response, this.filterChain);
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
@ -152,7 +149,8 @@ public class CsrfFilterTests {
@Test @Test
public void doFilterAccessDeniedIncorrectTokenPresentHeader() throws ServletException, IOException { public void doFilterAccessDeniedIncorrectTokenPresentHeader() throws ServletException, IOException {
given(this.requestMatcher.matches(this.request)).willReturn(true); given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID"); this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID");
this.filter.doFilter(this.request, this.response, this.filterChain); this.filter.doFilter(this.request, this.response, this.filterChain);
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
@ -165,7 +163,8 @@ public class CsrfFilterTests {
public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter() public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter()
throws ServletException, IOException { throws ServletException, IOException {
given(this.requestMatcher.matches(this.request)).willReturn(true); given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.request.setParameter(this.token.getParameterName(), this.token.getToken());
this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID"); this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID");
this.filter.doFilter(this.request, this.response, this.filterChain); this.filter.doFilter(this.request, this.response, this.filterChain);
@ -178,7 +177,8 @@ public class CsrfFilterTests {
@Test @Test
public void doFilterNotCsrfRequestExistingToken() throws ServletException, IOException { public void doFilterNotCsrfRequestExistingToken() throws ServletException, IOException {
given(this.requestMatcher.matches(this.request)).willReturn(false); given(this.requestMatcher.matches(this.request)).willReturn(false);
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
this.filter.doFilter(this.request, this.response, this.filterChain); this.filter.doFilter(this.request, this.response, this.filterChain);
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
@ -189,7 +189,8 @@ public class CsrfFilterTests {
@Test @Test
public void doFilterNotCsrfRequestGenerateToken() throws ServletException, IOException { public void doFilterNotCsrfRequestGenerateToken() throws ServletException, IOException {
given(this.requestMatcher.matches(this.request)).willReturn(false); given(this.requestMatcher.matches(this.request)).willReturn(false);
given(this.tokenRepository.generateToken(this.request)).willReturn(this.token); given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, true));
this.filter.doFilter(this.request, this.response, this.filterChain); this.filter.doFilter(this.request, this.response, this.filterChain);
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
@ -200,7 +201,8 @@ public class CsrfFilterTests {
@Test @Test
public void doFilterIsCsrfRequestExistingTokenHeader() throws ServletException, IOException { public void doFilterIsCsrfRequestExistingTokenHeader() throws ServletException, IOException {
given(this.requestMatcher.matches(this.request)).willReturn(true); given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
this.request.addHeader(this.token.getHeaderName(), this.token.getToken()); this.request.addHeader(this.token.getHeaderName(), this.token.getToken());
this.filter.doFilter(this.request, this.response, this.filterChain); this.filter.doFilter(this.request, this.response, this.filterChain);
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
@ -213,7 +215,8 @@ public class CsrfFilterTests {
public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam() public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam()
throws ServletException, IOException { throws ServletException, IOException {
given(this.requestMatcher.matches(this.request)).willReturn(true); given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID"); this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID");
this.request.addHeader(this.token.getHeaderName(), this.token.getToken()); this.request.addHeader(this.token.getHeaderName(), this.token.getToken());
this.filter.doFilter(this.request, this.response, this.filterChain); this.filter.doFilter(this.request, this.response, this.filterChain);
@ -226,7 +229,8 @@ public class CsrfFilterTests {
@Test @Test
public void doFilterIsCsrfRequestExistingToken() throws ServletException, IOException { public void doFilterIsCsrfRequestExistingToken() throws ServletException, IOException {
given(this.requestMatcher.matches(this.request)).willReturn(true); given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.request.setParameter(this.token.getParameterName(), this.token.getToken());
this.filter.doFilter(this.request, this.response, this.filterChain); this.filter.doFilter(this.request, this.response, this.filterChain);
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
@ -240,7 +244,8 @@ public class CsrfFilterTests {
@Test @Test
public void doFilterIsCsrfRequestGenerateToken() throws ServletException, IOException { public void doFilterIsCsrfRequestGenerateToken() throws ServletException, IOException {
given(this.requestMatcher.matches(this.request)).willReturn(true); given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.generateToken(this.request)).willReturn(this.token); given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, true));
this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.request.setParameter(this.token.getParameterName(), this.token.getToken());
this.filter.doFilter(this.request, this.response, this.filterChain); this.filter.doFilter(this.request, this.response, this.filterChain);
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
@ -248,16 +253,17 @@ public class CsrfFilterTests {
// LazyCsrfTokenRepository requires the response as an attribute // LazyCsrfTokenRepository requires the response as an attribute
assertThat(this.request.getAttribute(HttpServletResponse.class.getName())).isEqualTo(this.response); assertThat(this.request.getAttribute(HttpServletResponse.class.getName())).isEqualTo(this.response);
verify(this.filterChain).doFilter(this.request, this.response); verify(this.filterChain).doFilter(this.request, this.response);
verify(this.tokenRepository).saveToken(this.token, this.request, this.response);
verifyNoMoreInteractions(this.deniedHandler); verifyNoMoreInteractions(this.deniedHandler);
} }
@Test @Test
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods() throws ServletException, IOException { public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods() throws ServletException, IOException {
this.filter = createCsrfFilter(this.tokenRepository); this.filter = new CsrfFilter(this.tokenRepository);
this.filter.setAccessDeniedHandler(this.deniedHandler); this.filter.setAccessDeniedHandler(this.deniedHandler);
for (String method : Arrays.asList("GET", "TRACE", "OPTIONS", "HEAD")) { for (String method : Arrays.asList("GET", "TRACE", "OPTIONS", "HEAD")) {
resetRequestResponse(); resetRequestResponse();
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
this.request.setMethod(method); this.request.setMethod(method);
this.filter.doFilter(this.request, this.response, this.filterChain); this.filter.doFilter(this.request, this.response, this.filterChain);
verify(this.filterChain).doFilter(this.request, this.response); verify(this.filterChain).doFilter(this.request, this.response);
@ -273,11 +279,12 @@ public class CsrfFilterTests {
*/ */
@Test @Test
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethodsCaseSensitive() throws Exception { public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethodsCaseSensitive() throws Exception {
this.filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.tokenRepository)); this.filter = new CsrfFilter(this.tokenRepository);
this.filter.setAccessDeniedHandler(this.deniedHandler); this.filter.setAccessDeniedHandler(this.deniedHandler);
for (String method : Arrays.asList("get", "TrAcE", "oPTIOnS", "hEaD")) { for (String method : Arrays.asList("get", "TrAcE", "oPTIOnS", "hEaD")) {
resetRequestResponse(); resetRequestResponse();
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
this.request.setMethod(method); this.request.setMethod(method);
this.filter.doFilter(this.request, this.response, this.filterChain); this.filter.doFilter(this.request, this.response, this.filterChain);
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), verify(this.deniedHandler).handle(eq(this.request), eq(this.response),
@ -288,11 +295,12 @@ public class CsrfFilterTests {
@Test @Test
public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods() throws ServletException, IOException { public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods() throws ServletException, IOException {
this.filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.tokenRepository)); this.filter = new CsrfFilter(this.tokenRepository);
this.filter.setAccessDeniedHandler(this.deniedHandler); this.filter.setAccessDeniedHandler(this.deniedHandler);
for (String method : Arrays.asList("POST", "PUT", "PATCH", "DELETE", "INVALID")) { for (String method : Arrays.asList("POST", "PUT", "PATCH", "DELETE", "INVALID")) {
resetRequestResponse(); resetRequestResponse();
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
this.request.setMethod(method); this.request.setMethod(method);
this.filter.doFilter(this.request, this.response, this.filterChain); this.filter.doFilter(this.request, this.response, this.filterChain);
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), verify(this.deniedHandler).handle(eq(this.request), eq(this.response),
@ -303,10 +311,11 @@ public class CsrfFilterTests {
@Test @Test
public void doFilterDefaultAccessDenied() throws ServletException, IOException { public void doFilterDefaultAccessDenied() throws ServletException, IOException {
this.filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.tokenRepository)); this.filter = new CsrfFilter(this.tokenRepository);
this.filter.setRequireCsrfProtectionMatcher(this.requestMatcher); this.filter.setRequireCsrfProtectionMatcher(this.requestMatcher);
given(this.requestMatcher.matches(this.request)).willReturn(true); given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
this.filter.doFilter(this.request, this.response, this.filterChain); this.filter.doFilter(this.request, this.response, this.filterChain);
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
@ -317,7 +326,7 @@ public class CsrfFilterTests {
@Test @Test
public void doFilterWhenSkipRequestInvokedThenSkips() throws Exception { public void doFilterWhenSkipRequestInvokedThenSkips() throws Exception {
CsrfTokenRepository repository = mock(CsrfTokenRepository.class); CsrfTokenRepository repository = mock(CsrfTokenRepository.class);
CsrfFilter filter = createCsrfFilter(repository); CsrfFilter filter = new CsrfFilter(repository);
lenient().when(repository.loadToken(any(HttpServletRequest.class))).thenReturn(this.token); lenient().when(repository.loadToken(any(HttpServletRequest.class))).thenReturn(this.token);
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = new MockHttpServletRequest();
CsrfFilter.skipRequest(request); CsrfFilter.skipRequest(request);
@ -333,7 +342,8 @@ public class CsrfFilterTests {
given(token.getToken()).willReturn(null); given(token.getToken()).willReturn(null);
given(token.getHeaderName()).willReturn(this.token.getHeaderName()); given(token.getHeaderName()).willReturn(this.token.getHeaderName());
given(token.getParameterName()).willReturn(this.token.getParameterName()); given(token.getParameterName()).willReturn(this.token.getParameterName());
given(this.tokenRepository.loadToken(this.request)).willReturn(token); given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(token, false));
given(this.requestMatcher.matches(this.request)).willReturn(true); given(this.requestMatcher.matches(this.request)).willReturn(true);
filter.doFilterInternal(this.request, this.response, this.filterChain); filter.doFilterInternal(this.request, this.response, this.filterChain);
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK);
@ -341,13 +351,15 @@ public class CsrfFilterTests {
@Test @Test
public void doFilterWhenRequestHandlerThenUsed() throws Exception { public void doFilterWhenRequestHandlerThenUsed() throws Exception {
CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class); given(this.tokenRepository.loadDeferredToken(this.request, this.response))
given(requestHandler.handle(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false)); .willReturn(new TestDeferredCsrfToken(this.token, false));
this.filter = createCsrfFilter(requestHandler); CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class);
this.filter = createCsrfFilter(this.tokenRepository);
this.filter.setRequestHandler(requestHandler);
this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.request.setParameter(this.token.getParameterName(), this.token.getToken());
this.filter.doFilter(this.request, this.response, this.filterChain); this.filter.doFilter(this.request, this.response, this.filterChain);
verify(requestHandler).handle(eq(this.request), eq(this.response)); verify(this.tokenRepository).loadDeferredToken(this.request, this.response);
verify(requestHandler).handle(eq(this.request), eq(this.response), any());
verify(this.filterChain).doFilter(this.request, this.response); verify(this.filterChain).doFilter(this.request, this.response);
} }
@ -365,41 +377,20 @@ public class CsrfFilterTests {
@Test @Test
public void doFilterWhenCsrfRequestAttributeNameThenNoCsrfTokenMethodInvokedOnGet() public void doFilterWhenCsrfRequestAttributeNameThenNoCsrfTokenMethodInvokedOnGet()
throws ServletException, IOException { throws ServletException, IOException {
CsrfFilter filter = createCsrfFilter(this.tokenRepository);
String csrfAttrName = "_csrf"; String csrfAttrName = "_csrf";
CsrfTokenRepositoryRequestHandler requestHandler = new CsrfTokenRepositoryRequestHandler(this.tokenRepository); CsrfTokenRequestAttributeHandler requestHandler = new CsrfTokenRequestAttributeHandler();
requestHandler.setCsrfRequestAttributeName(csrfAttrName); requestHandler.setCsrfRequestAttributeName(csrfAttrName);
this.filter = createCsrfFilter(requestHandler); filter.setRequestHandler(requestHandler);
CsrfToken expectedCsrfToken = spy(this.token); CsrfToken expectedCsrfToken = mock(CsrfToken.class);
given(this.tokenRepository.loadToken(this.request)).willReturn(expectedCsrfToken); given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(expectedCsrfToken, true));
this.filter.doFilter(this.request, this.response, this.filterChain); filter.doFilter(this.request, this.response, this.filterChain);
verifyNoInteractions(expectedCsrfToken); verifyNoInteractions(expectedCsrfToken);
CsrfToken tokenFromRequest = (CsrfToken) this.request.getAttribute(csrfAttrName); CsrfToken tokenFromRequest = (CsrfToken) this.request.getAttribute(csrfAttrName);
assertThatCsrfToken(tokenFromRequest).isEqualTo(expectedCsrfToken); assertThatCsrfToken(tokenFromRequest).isEqualTo(expectedCsrfToken);
} }
private static final class TestDeferredCsrfToken implements DeferredCsrfToken {
private final CsrfToken csrfToken;
private final boolean isGenerated;
private TestDeferredCsrfToken(CsrfToken csrfToken, boolean isGenerated) {
this.csrfToken = csrfToken;
this.isGenerated = isGenerated;
}
@Override
public CsrfToken get() {
return this.csrfToken;
}
@Override
public boolean isGenerated() {
return this.isGenerated;
}
}
} }

View File

@ -18,29 +18,22 @@ package org.springframework.security.web.csrf;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.BDDMockito.given; import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken; import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;
/** /**
* Tests for {@link CsrfTokenRepositoryRequestHandler}. * Tests for {@link CsrfTokenRequestAttributeHandler}.
* *
* @author Steve Riesenberg * @author Steve Riesenberg
* @since 5.8 * @since 5.8
*/ */
@ExtendWith(MockitoExtension.class) public class CsrfTokenRequestAttributeHandlerTests {
public class CsrfTokenRepositoryRequestHandlerTests {
@Mock
CsrfTokenRepository tokenRepository;
private MockHttpServletRequest request; private MockHttpServletRequest request;
@ -48,76 +41,73 @@ public class CsrfTokenRepositoryRequestHandlerTests {
private CsrfToken token; private CsrfToken token;
private CsrfTokenRepositoryRequestHandler handler; private CsrfTokenRequestAttributeHandler handler;
@BeforeEach @BeforeEach
public void setup() { public void setup() {
this.request = new MockHttpServletRequest(); this.request = new MockHttpServletRequest();
this.response = new MockHttpServletResponse(); this.response = new MockHttpServletResponse();
this.token = new DefaultCsrfToken("headerName", "paramName", "csrfTokenValue"); this.token = new DefaultCsrfToken("headerName", "paramName", "csrfTokenValue");
this.handler = new CsrfTokenRepositoryRequestHandler(this.tokenRepository); this.handler = new CsrfTokenRequestAttributeHandler();
}
@Test
public void constructorWhenCsrfTokenRepositoryIsNullThenThrowsIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> new CsrfTokenRepositoryRequestHandler(null))
.withMessage("csrfTokenRepository cannot be null");
// @formatter:on
} }
@Test @Test
public void handleWhenRequestIsNullThenThrowsIllegalArgumentException() { public void handleWhenRequestIsNullThenThrowsIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException() assertThatIllegalArgumentException()
.isThrownBy(() -> this.handler.handle(null, this.response)) .isThrownBy(() -> this.handler.handle(null, this.response, () -> this.token))
.withMessage("request cannot be null"); .withMessage("request cannot be null");
// @formatter:on
} }
@Test @Test
public void handleWhenResponseIsNullThenThrowsIllegalArgumentException() { public void handleWhenResponseIsNullThenThrowsIllegalArgumentException() {
// @formatter:off // @formatter:off
assertThatIllegalArgumentException() assertThatIllegalArgumentException()
.isThrownBy(() -> this.handler.handle(this.request, null)) .isThrownBy(() -> this.handler.handle(this.request, null, () -> this.token))
.withMessage("response cannot be null"); .withMessage("response cannot be null");
// @formatter:on // @formatter:on
} }
@Test
public void handleWhenCsrfTokenSupplierIsNullThenThrowsIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.handler.handle(this.request, this.response, null))
.withMessage("deferredCsrfToken cannot be null");
}
@Test
public void handleWhenCsrfTokenIsNullThenThrowsIllegalArgumentException() {
// @formatter:off
this.handler.setCsrfRequestAttributeName(null);
assertThatIllegalStateException()
.isThrownBy(() -> this.handler.handle(this.request, this.response, () -> null))
.withMessage("csrfTokenSupplier returned null delegate");
// @formatter:on
}
@Test @Test
public void handleWhenCsrfRequestAttributeSetThenUsed() { public void handleWhenCsrfRequestAttributeSetThenUsed() {
given(this.tokenRepository.generateToken(this.request)).willReturn(this.token);
this.handler.setCsrfRequestAttributeName("_csrf"); this.handler.setCsrfRequestAttributeName("_csrf");
this.handler.handle(this.request, this.response); this.handler.handle(this.request, this.response, () -> this.token);
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute("_csrf")).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute("_csrf")).isEqualTo(this.token);
} }
@Test @Test
public void handleWhenValidParametersThenRequestAttributesSet() { public void handleWhenValidParametersThenRequestAttributesSet() {
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); this.handler.handle(this.request, this.response, () -> this.token);
this.handler.handle(this.request, this.response);
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute("_csrf")).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute("_csrf")).isEqualTo(this.token);
} }
@Test @Test
public void resolveCsrfTokenValueWhenRequestIsNullThenThrowsIllegalArgumentException() { public void resolveCsrfTokenValueWhenRequestIsNullThenThrowsIllegalArgumentException() {
// @formatter:off assertThatIllegalArgumentException().isThrownBy(() -> this.handler.resolveCsrfTokenValue(null, this.token))
assertThatIllegalArgumentException()
.isThrownBy(() -> this.handler.resolveCsrfTokenValue(null, this.token))
.withMessage("request cannot be null"); .withMessage("request cannot be null");
// @formatter:on
} }
@Test @Test
public void resolveCsrfTokenValueWhenCsrfTokenIsNullThenThrowsIllegalArgumentException() { public void resolveCsrfTokenValueWhenCsrfTokenIsNullThenThrowsIllegalArgumentException() {
// @formatter:off assertThatIllegalArgumentException().isThrownBy(() -> this.handler.resolveCsrfTokenValue(this.request, null))
assertThatIllegalArgumentException()
.isThrownBy(() -> this.handler.resolveCsrfTokenValue(this.request, null))
.withMessage("csrfToken cannot be null"); .withMessage("csrfToken cannot be null");
// @formatter:on
} }
@Test @Test

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2013 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -24,6 +24,7 @@ import org.springframework.mock.web.MockHttpServletResponse;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;
/** /**
* @author Rob Winch * @author Rob Winch
@ -85,6 +86,26 @@ public class HttpSessionCsrfTokenRepositoryTests {
assertThat(this.repo.loadToken(this.request)).isNull(); assertThat(this.repo.loadToken(this.request)).isNull();
} }
@Test
public void loadDeferredTokenWhenDoesNotExistThenGeneratedAndSaved() {
DeferredCsrfToken deferredCsrfToken = this.repo.loadDeferredToken(this.request, this.response);
CsrfToken csrfToken = deferredCsrfToken.get();
assertThat(csrfToken).isNotNull();
assertThat(deferredCsrfToken.isGenerated()).isTrue();
String attrName = this.request.getSession().getAttributeNames().nextElement();
assertThatCsrfToken(this.request.getSession().getAttribute(attrName)).isEqualTo(csrfToken);
}
@Test
public void loadDeferredTokenWhenExistsThenLoaded() {
CsrfToken tokenToSave = new DefaultCsrfToken("123", "abc", "def");
this.repo.saveToken(tokenToSave, this.request, this.response);
DeferredCsrfToken deferredCsrfToken = this.repo.loadDeferredToken(this.request, this.response);
CsrfToken csrfToken = deferredCsrfToken.get();
assertThatCsrfToken(csrfToken).isEqualTo(tokenToSave);
assertThat(deferredCsrfToken.isGenerated()).isFalse();
}
@Test @Test
public void saveToken() { public void saveToken() {
CsrfToken tokenToSave = new DefaultCsrfToken("123", "abc", "def"); CsrfToken tokenToSave = new DefaultCsrfToken("123", "abc", "def");

View File

@ -0,0 +1,40 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.web.csrf;
final class TestDeferredCsrfToken implements DeferredCsrfToken {
private final CsrfToken csrfToken;
private final boolean isGenerated;
TestDeferredCsrfToken(CsrfToken csrfToken, boolean isGenerated) {
this.csrfToken = csrfToken;
this.isGenerated = isGenerated;
}
@Override
public CsrfToken get() {
return this.csrfToken;
}
@Override
public boolean isGenerated() {
return this.isGenerated;
}
}