diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java index 16cf1c0e18..1f7510b4e4 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java @@ -87,6 +87,7 @@ import org.springframework.security.web.authentication.LoginUrlAuthenticationEnt import org.springframework.security.web.authentication.session.SessionAuthenticationException; import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy; import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter; +import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.savedrequest.RequestCache; import org.springframework.security.web.util.matcher.AndRequestMatcher; @@ -177,6 +178,8 @@ public final class OAuth2LoginConfigurer> private OAuth2AuthorizedClientRepository authorizedClientRepository; + private SecurityContextRepository securityContextRepository; + /** * Sets the repository of client registrations. * @param clientRegistrationRepository the repository of client registrations @@ -230,6 +233,17 @@ public final class OAuth2LoginConfigurer> return this; } + /** + * Sets the {@link SecurityContextRepository} to use. + * @param securityContextRepository the {@link SecurityContextRepository} to use + * @return the {@link OAuth2LoginConfigurer} for further configuration + */ + @Override + public OAuth2LoginConfigurer securityContextRepository(SecurityContextRepository securityContextRepository) { + this.securityContextRepository = securityContextRepository; + return this; + } + /** * Sets the registry for managing the OIDC client-provider session link * @param oidcSessionRegistry the {@link OidcSessionRegistry} to use @@ -348,6 +362,9 @@ public final class OAuth2LoginConfigurer> OAuth2LoginAuthenticationFilter authenticationFilter = new OAuth2LoginAuthenticationFilter( this.getClientRegistrationRepository(), this.getAuthorizedClientRepository(), this.loginProcessingUrl); authenticationFilter.setSecurityContextHolderStrategy(getSecurityContextHolderStrategy()); + if (this.securityContextRepository != null) { + authenticationFilter.setSecurityContextRepository(this.securityContextRepository); + } this.setAuthenticationFilter(authenticationFilter); super.loginProcessingUrl(this.loginProcessingUrl); if (this.loginPage != null) { diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java index dfe6fea28f..d7bcfb4e33 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java @@ -101,6 +101,7 @@ import org.springframework.security.web.SecurityFilterChain; import org.springframework.security.web.authentication.HttpStatusEntryPoint; import org.springframework.security.web.context.HttpRequestResponseHolder; import org.springframework.security.web.context.HttpSessionSecurityContextRepository; +import org.springframework.security.web.context.NullSecurityContextRepository; import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.session.HttpSessionDestroyedEvent; import org.springframework.security.web.util.matcher.RequestHeaderRequestMatcher; @@ -110,6 +111,7 @@ import org.springframework.web.context.support.AnnotationConfigWebApplicationCon import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatNoException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.BDDMockito.given; @@ -696,6 +698,12 @@ public class OAuth2LoginConfigurerTests { verifyNoInteractions(clientRegistrationRepository, authorizedClientRepository); } + // gh-16623 + @Test + public void oauth2LoginWithCustomSecurityContextRepository() { + assertThatNoException().isThrownBy(() -> loadConfig(OAuth2LoginConfigSecurityContextRepository.class)); + } + private void loadConfig(Class... configs) { AnnotationConfigWebApplicationContext applicationContext = new AnnotationConfigWebApplicationContext(); applicationContext.register(configs); @@ -944,6 +952,24 @@ public class OAuth2LoginConfigurerTests { } + @Configuration + @EnableWebSecurity + static class OAuth2LoginConfigSecurityContextRepository extends CommonSecurityFilterChainConfig { + + @Bean + SecurityFilterChain filterChain(HttpSecurity http) throws Exception { + // @formatter:off + http + .oauth2Login((login) -> login + .clientRegistrationRepository( + new InMemoryClientRegistrationRepository(GOOGLE_CLIENT_REGISTRATION)) + .securityContextRepository(new NullSecurityContextRepository())); + // @formatter:on + return super.configureFilterChain(http); + } + + } + @Configuration @EnableWebSecurity static class OAuth2LoginConfigCustomAuthorizationRequestResolver extends CommonSecurityFilterChainConfig {