diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java index c27cb4f2a7..27dfc8f7aa 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java @@ -350,11 +350,9 @@ public final class OAuth2LoginConfigurer> oidcAuthorizationCodeAuthenticationProvider.setAuthoritiesMapper(userAuthoritiesMapper); oidcAuthorizedClientRefreshedEventListener.setAuthoritiesMapper(userAuthoritiesMapper); } - oidcAuthorizationCodeAuthenticationProvider = this.postProcess(oidcAuthorizationCodeAuthenticationProvider); - http.authenticationProvider(oidcAuthorizationCodeAuthenticationProvider); + http.authenticationProvider(this.postProcess(oidcAuthorizationCodeAuthenticationProvider)); - oidcAuthorizedClientRefreshedEventListener = this.postProcess(oidcAuthorizedClientRefreshedEventListener); - registerDelegateApplicationListener(oidcAuthorizedClientRefreshedEventListener); + registerDelegateApplicationListener(this.postProcess(oidcAuthorizedClientRefreshedEventListener)); configureOidcUserRefreshedEventListener(http); } else { diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java index 6428609517..cc1a30a381 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java @@ -29,6 +29,7 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mockito; import org.springframework.beans.factory.BeanCreationException; import org.springframework.beans.factory.NoUniqueBeanDefinitionException; @@ -43,11 +44,14 @@ import org.springframework.http.MediaType; import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.event.AuthenticationSuccessEvent; import org.springframework.security.config.Customizer; +import org.springframework.security.config.ObjectPostProcessor; import org.springframework.security.config.annotation.SecurityContextChangedListenerConfig; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; +import org.springframework.security.config.annotation.web.configurers.oauth2.client.OAuth2LoginConfigurerTests.OAuth2LoginConfigCustomWithPostProcessor.SpyObjectPostProcessor; import org.springframework.security.config.oauth2.client.CommonOAuth2Provider; import org.springframework.security.config.test.SpringTestContext; import org.springframework.security.config.test.SpringTestContextExtension; @@ -711,6 +715,22 @@ public class OAuth2LoginConfigurerTests { verifyNoInteractions(clientRegistrationRepository, authorizedClientRepository); } + // gh-17175 + @Test + public void oauth2LoginWhenAuthenticationProviderPostProcessorThenUses() throws Exception { + loadConfig(OAuth2LoginConfigCustomWithPostProcessor.class); + // setup authorization request + OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest(); + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, this.request, this.response); + // setup authentication parameters + this.request.setParameter("code", "code123"); + this.request.setParameter("state", authorizationRequest.getState()); + // perform test + this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); + // assertions + verify(this.context.getBean(SpyObjectPostProcessor.class).spy).authenticate(any()); + } + private void loadConfig(Class... configs) { AnnotationConfigWebApplicationContext applicationContext = new AnnotationConfigWebApplicationContext(); applicationContext.register(configs); @@ -1296,6 +1316,52 @@ public class OAuth2LoginConfigurerTests { } + @Configuration + @EnableWebSecurity + static class OAuth2LoginConfigCustomWithPostProcessor { + + private final ClientRegistrationRepository clientRegistrationRepository = new InMemoryClientRegistrationRepository( + GOOGLE_CLIENT_REGISTRATION); + + private final ObjectPostProcessor postProcessor = new SpyObjectPostProcessor(); + + @Bean + SecurityFilterChain filterChain(HttpSecurity http) throws Exception { + // @formatter:off + http + .oauth2Login((oauth2Login) -> oauth2Login + .clientRegistrationRepository(this.clientRegistrationRepository) + .withObjectPostProcessor(this.postProcessor) + ); + // @formatter:on + return http.build(); + } + + @Bean + ObjectPostProcessor mockPostProcessor() { + return this.postProcessor; + } + + @Bean + HttpSessionOAuth2AuthorizationRequestRepository oauth2AuthorizationRequestRepository() { + return new HttpSessionOAuth2AuthorizationRequestRepository(); + } + + static class SpyObjectPostProcessor implements ObjectPostProcessor { + + AuthenticationProvider spy; + + @Override + public O postProcess(O object) { + O spy = Mockito.spy(object); + this.spy = spy; + return spy; + } + + } + + } + private abstract static class CommonSecurityFilterChainConfig { SecurityFilterChain configureFilterChain(HttpSecurity http) throws Exception {