SessionManagementConfigurer properly defaults SecurityContextRepository

Previously the default was an HttpSessionSecurityContextRepository which
meant that if a stateless authentication occurred the SecurityContext would
be lost on ERROR dispatch.

This commit ensures that the RequestAttributeSecurityContextRepository is
also consulted by default.

Closes gh-12070
This commit is contained in:
Rob Winch 2022-10-19 22:40:31 -05:00
parent a4858d9eaa
commit 9cb668aec2
3 changed files with 109 additions and 3 deletions

View File

@ -48,7 +48,9 @@ import org.springframework.security.web.authentication.session.NullAuthenticated
import org.springframework.security.web.authentication.session.RegisterSessionAuthenticationStrategy;
import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy;
import org.springframework.security.web.authentication.session.SessionFixationProtectionStrategy;
import org.springframework.security.web.context.DelegatingSecurityContextRepository;
import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
import org.springframework.security.web.context.NullSecurityContextRepository;
import org.springframework.security.web.context.RequestAttributeSecurityContextRepository;
import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.security.web.savedrequest.NullRequestCache;
@ -141,6 +143,12 @@ public final class SessionManagementConfigurer<H extends HttpSecurityBuilder<H>>
private Boolean requireExplicitAuthenticationStrategy;
/**
* This should not use RequestAttributeSecurityContextRepository since that is
* stateless and sesison management is about state management.
*/
private SecurityContextRepository sessionManagementSecurityContextRepository = new HttpSessionSecurityContextRepository();
/**
* Creates a new instance
* @see HttpSecurity#sessionManagement()
@ -356,6 +364,7 @@ public final class SessionManagementConfigurer<H extends HttpSecurityBuilder<H>>
if (securityContextRepository == null) {
if (stateless) {
http.setSharedObject(SecurityContextRepository.class, new RequestAttributeSecurityContextRepository());
this.sessionManagementSecurityContextRepository = new NullSecurityContextRepository();
}
else {
HttpSessionSecurityContextRepository httpSecurityRepository = new HttpSessionSecurityContextRepository();
@ -365,7 +374,10 @@ public final class SessionManagementConfigurer<H extends HttpSecurityBuilder<H>>
if (trustResolver != null) {
httpSecurityRepository.setTrustResolver(trustResolver);
}
http.setSharedObject(SecurityContextRepository.class, httpSecurityRepository);
this.sessionManagementSecurityContextRepository = httpSecurityRepository;
DelegatingSecurityContextRepository defaultRepository = new DelegatingSecurityContextRepository(
httpSecurityRepository, new RequestAttributeSecurityContextRepository());
http.setSharedObject(SecurityContextRepository.class, defaultRepository);
}
}
RequestCache requestCache = http.getSharedObject(RequestCache.class);
@ -420,7 +432,7 @@ public final class SessionManagementConfigurer<H extends HttpSecurityBuilder<H>>
if (shouldRequireExplicitAuthenticationStrategy()) {
return null;
}
SecurityContextRepository securityContextRepository = http.getSharedObject(SecurityContextRepository.class);
SecurityContextRepository securityContextRepository = this.sessionManagementSecurityContextRepository;
SessionManagementFilter sessionManagementFilter = new SessionManagementFilter(securityContextRepository,
getSessionAuthenticationStrategy(http));
if (this.sessionAuthenticationErrorUrl != null) {

View File

@ -16,6 +16,14 @@
package org.springframework.security.config.annotation.web.configurers;
import java.io.IOException;
import jakarta.servlet.DispatcherType;
import jakarta.servlet.Filter;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.HttpSession;
@ -25,8 +33,13 @@ import org.junit.jupiter.api.extension.ExtendWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpSession;
import org.springframework.security.authentication.AuthenticationTrustResolver;
import org.springframework.security.config.Customizer;
import org.springframework.security.config.TestDeferredSecurityContext;
import org.springframework.security.config.annotation.ObjectPostProcessor;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
@ -55,8 +68,10 @@ import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult;
import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder;
import org.springframework.test.web.servlet.setup.MockMvcBuilders;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.util.WebUtils;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
@ -360,6 +375,84 @@ public class SessionManagementConfigurerTests {
assertThat(securityContext).isNotNull();
}
/**
* This ensures that if an ErrorDispatch occurs, then the SecurityContextRepository
* defaulted by SessionManagementConfigurer is correct (looks at both Session and
* Request Attributes).
* @throws Exception
*/
@Test
public void gh12070WhenErrorDispatchSecurityContextRepositoryWorks() throws Exception {
Filter errorDispatchFilter = new Filter() {
@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException {
try {
chain.doFilter(request, response);
}
catch (ServletException ex) {
if (request.getDispatcherType() == DispatcherType.ERROR) {
throw ex;
}
MockHttpServletRequest httpRequest = WebUtils.getNativeRequest(request,
MockHttpServletRequest.class);
httpRequest.setDispatcherType(DispatcherType.ERROR);
// necessary to prevent HttpBasicFilter from invoking again
httpRequest.setAttribute(WebUtils.ERROR_REQUEST_URI_ATTRIBUTE, "/error");
httpRequest.setRequestURI("/error");
MockFilterChain mockChain = (MockFilterChain) chain;
mockChain.reset();
mockChain.doFilter(httpRequest, response);
}
}
};
this.spring.addFilter(errorDispatchFilter).register(Gh12070IssueConfig.class).autowire();
// @formatter:off
this.mvc.perform(get("/500").with(httpBasic("user", "password")))
.andExpect(status().isInternalServerError());
// @formatter:on
}
@Configuration
@EnableWebSecurity
static class Gh12070IssueConfig {
@Bean
SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
// @formatter:off
http
.authorizeHttpRequests((authorize) -> authorize
.anyRequest().authenticated()
)
.httpBasic(Customizer.withDefaults())
.formLogin(Customizer.withDefaults());
return http.build();
// @formatter:on
}
@Bean
UserDetailsService userDetailsService() {
return new InMemoryUserDetailsManager(PasswordEncodedUser.user());
}
@RestController
static class ErrorController {
@GetMapping("/500")
String error() throws ServletException {
throw new ServletException("Error");
}
@GetMapping("/error")
ResponseEntity<String> errorHandler() {
return new ResponseEntity<>("error", HttpStatus.INTERNAL_SERVER_ERROR);
}
}
}
@Configuration
@EnableWebSecurity
static class SessionManagementRequestCacheConfig {

View File

@ -33,6 +33,7 @@ import org.springframework.security.config.annotation.web.configuration.EnableWe
import org.springframework.security.web.DefaultSecurityFilterChain;
import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.context.DelegatingSecurityContextRepository;
import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
import org.springframework.security.web.context.SecurityContextHolderFilter;
import org.springframework.security.web.context.SecurityContextPersistenceFilter;
@ -119,7 +120,7 @@ public class WebTestUtilsTests {
public void getSecurityContextRepositorySecurityNoCsrf() {
loadConfig(SecurityNoCsrfConfig.class);
assertThat(WebTestUtils.getSecurityContextRepository(this.request))
.isInstanceOf(HttpSessionSecurityContextRepository.class);
.isInstanceOf(DelegatingSecurityContextRepository.class);
}
@Test