diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index f26f910185..0bbf004c7f 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -3738,7 +3738,8 @@ public class ServerHttpSecurity { */ public final class LogoutSpec { private LogoutWebFilter logoutWebFilter = new LogoutWebFilter(); - private List logoutHandlers = new ArrayList<>(Arrays.asList(new SecurityContextServerLogoutHandler())); + private final SecurityContextServerLogoutHandler DEFAULT_LOGOUT_HANDLER = new SecurityContextServerLogoutHandler(); + private List logoutHandlers = new ArrayList<>(Arrays.asList(this.DEFAULT_LOGOUT_HANDLER)); /** * Configures the logout handler. Default is {@code SecurityContextServerLogoutHandler} @@ -3802,6 +3803,10 @@ public class ServerHttpSecurity { } private ServerLogoutHandler createLogoutHandler() { + ServerSecurityContextRepository securityContextRepository = ServerHttpSecurity.this.securityContextRepository; + if (securityContextRepository != null) { + this.DEFAULT_LOGOUT_HANDLER.setSecurityContextRepository(securityContextRepository); + } if (this.logoutHandlers.isEmpty()) { return null; } else if (this.logoutHandlers.size() == 1) { diff --git a/config/src/test/java/org/springframework/security/config/web/server/LogoutSpecTests.java b/config/src/test/java/org/springframework/security/config/web/server/LogoutSpecTests.java index e417a3cda5..723251e4cf 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/LogoutSpecTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/LogoutSpecTests.java @@ -21,6 +21,7 @@ import org.openqa.selenium.WebDriver; import org.springframework.security.config.annotation.web.reactive.ServerHttpSecurityConfigurationBuilder; import org.springframework.security.htmlunit.server.WebTestClientHtmlUnitDriverBuilder; import org.springframework.security.web.server.SecurityWebFilterChain; +import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers; import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.security.test.web.reactive.server.WebTestClientBuilder; @@ -200,4 +201,46 @@ public class LogoutSpecTests { homePage .assertAt(); } + + + @Test + public void logoutWhenCustomSecurityContextRepositoryThenLogsOut() { + WebSessionServerSecurityContextRepository repository = new WebSessionServerSecurityContextRepository(); + repository.setSpringSecurityContextAttrName("CUSTOM_CONTEXT_ATTR"); + SecurityWebFilterChain securityWebFilter = this.http + .securityContextRepository(repository) + .authorizeExchange() + .anyExchange().authenticated() + .and() + .formLogin() + .and() + .logout() + .and() + .build(); + + WebTestClient webTestClient = WebTestClientBuilder + .bindToWebFilters(securityWebFilter) + .build(); + + WebDriver driver = WebTestClientHtmlUnitDriverBuilder + .webTestClientSetup(webTestClient) + .build(); + + FormLoginTests.DefaultLoginPage loginPage = FormLoginTests.HomePage.to(driver, FormLoginTests.DefaultLoginPage.class) + .assertAt(); + + FormLoginTests.HomePage homePage = loginPage.loginForm() + .username("user") + .password("password") + .submit(FormLoginTests.HomePage.class); + + homePage.assertAt(); + + FormLoginTests.DefaultLogoutPage.to(driver) + .assertAt() + .logout(); + + FormLoginTests.HomePage.to(driver, FormLoginTests.DefaultLoginPage.class) + .assertAt(); + } } diff --git a/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java b/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java index c820748e57..4723da1806 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java @@ -83,6 +83,7 @@ import org.springframework.security.web.server.authentication.RedirectServerAuth import org.springframework.security.web.server.authentication.ServerAuthenticationConverter; import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler; import org.springframework.security.web.server.authentication.ServerAuthenticationSuccessHandler; +import org.springframework.security.web.server.authentication.logout.SecurityContextServerLogoutHandler; import org.springframework.security.web.server.context.ServerSecurityContextRepository; import org.springframework.security.web.server.savedrequest.ServerRequestCache; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; @@ -716,6 +717,8 @@ public class OAuth2LoginTests { http .csrf().disable() .logout() + // avoid using mock ServerSecurityContextRepository for logout + .logoutHandler(new SecurityContextServerLogoutHandler()) .logoutSuccessHandler( new OidcClientInitiatedServerLogoutSuccessHandler( new InMemoryReactiveClientRegistrationRepository(this.withLogout)))