diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java index a7aaf14f90..67c2422c58 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java @@ -270,6 +270,9 @@ public final class Saml2LoginConfigurer> } } this.initDefaultLoginFilter(http); + if (this.authenticationManager == null) { + registerDefaultAuthenticationProvider(http); + } } /** @@ -285,10 +288,7 @@ public final class Saml2LoginConfigurer> filter.setAuthenticationRequestRepository(getAuthenticationRequestRepository(http)); http.addFilter(postProcess(filter)); super.configure(http); - if (this.authenticationManager == null) { - registerDefaultAuthenticationProvider(http); - } - else { + if (this.authenticationManager != null) { this.saml2WebSsoAuthenticationFilter.setAuthenticationManager(this.authenticationManager); } } @@ -361,7 +361,10 @@ public final class Saml2LoginConfigurer> } private void registerDefaultAuthenticationProvider(B http) { - http.authenticationProvider(postProcess(new OpenSaml4AuthenticationProvider())); + OpenSaml4AuthenticationProvider provider = getBeanOrNull(http, OpenSaml4AuthenticationProvider.class); + if (provider == null) { + http.authenticationProvider(postProcess(new OpenSaml4AuthenticationProvider())); + } } private void registerDefaultCsrfOverride(B http) { diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java index 105d6f469a..7a635d8ff8 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java @@ -50,6 +50,7 @@ import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpSession; import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.AuthenticationServiceException; import org.springframework.security.config.Customizer; import org.springframework.security.config.annotation.SecurityContextChangedListenerConfig; @@ -68,6 +69,7 @@ import org.springframework.security.saml2.core.Saml2ErrorCodes; import org.springframework.security.saml2.core.Saml2Utils; import org.springframework.security.saml2.core.TestSaml2X509Credentials; import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest; +import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticatedPrincipal; import org.springframework.security.saml2.provider.service.authentication.Saml2Authentication; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException; @@ -390,6 +392,15 @@ public class Saml2LoginConfigurerTests { .andExpect(redirectedUrl("http://localhost/saml2/authenticate/registration-id")); } + @Test + public void saml2LoginWhenCustomAuthenticationProviderThenUses() throws Exception { + this.spring.register(CustomAuthenticationProviderConfig.class).autowire(); + AuthenticationProvider provider = this.spring.getContext().getBean(AuthenticationProvider.class); + this.mvc.perform(post("/login/saml2/sso/registration-id").param("SAMLResponse", SIGNED_RESPONSE)) + .andExpect(status().isFound()); + verify(provider).authenticate(any()); + } + private void performSaml2Login(String expected) throws IOException, ServletException { // setup authentication parameters this.request.setRequestURI("/login/saml2/sso/registration-id"); @@ -700,6 +711,29 @@ public class Saml2LoginConfigurerTests { } + @Configuration + @EnableWebSecurity + @EnableWebMvc + @Import(Saml2LoginConfigBeans.class) + static class CustomAuthenticationProviderConfig { + + private final OpenSaml4AuthenticationProvider provider = spy(new OpenSaml4AuthenticationProvider()); + + @Bean + SecurityFilterChain web(HttpSecurity http) throws Exception { + http.authorizeHttpRequests((authorize) -> authorize.anyRequest().authenticated()) + .saml2Login(Customizer.withDefaults()); + + return http.build(); + } + + @Bean + AuthenticationProvider provider() { + return this.provider; + } + + } + static class Saml2LoginConfigBeans { @Bean