diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java index b953520365..45b917ed6c 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java @@ -16,6 +16,10 @@ package org.springframework.security.config.annotation.web.configurers.saml2; +import java.util.LinkedHashMap; +import java.util.Map; +import javax.servlet.Filter; + import org.springframework.beans.factory.NoSuchBeanDefinitionException; import org.springframework.context.ApplicationContext; import org.springframework.security.authentication.AuthenticationManager; @@ -37,10 +41,6 @@ import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; -import java.util.LinkedHashMap; -import java.util.Map; -import javax.servlet.Filter; - import static org.springframework.util.StringUtils.hasText; /** @@ -323,10 +323,9 @@ public final class Saml2LoginConfigurer> extend private Filter build(B http) { Saml2AuthenticationRequestFactory authenticationRequestResolver = getResolver(http); - Saml2WebSsoAuthenticationRequestFilter authenticationRequestFilter = - new Saml2WebSsoAuthenticationRequestFilter(Saml2LoginConfigurer.this.relyingPartyRegistrationRepository); - authenticationRequestFilter.setAuthenticationRequestFactory(authenticationRequestResolver); - return authenticationRequestFilter; + return new Saml2WebSsoAuthenticationRequestFilter( + Saml2LoginConfigurer.this.relyingPartyRegistrationRepository, + authenticationRequestResolver); } private Saml2AuthenticationRequestFactory getResolver(B http) { diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java index 0d07368347..8f8051c5bf 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java @@ -24,7 +24,6 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.springframework.http.MediaType; -import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationRequestFactory; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory; import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest; @@ -71,24 +70,43 @@ import static org.springframework.util.StringUtils.hasText; public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter { private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository; + private Saml2AuthenticationRequestFactory authenticationRequestFactory; + private RequestMatcher redirectMatcher = new AntPathRequestMatcher("/saml2/authenticate/{registrationId}"); - private Saml2AuthenticationRequestFactory authenticationRequestFactory = new OpenSamlAuthenticationRequestFactory(); /** * Construct a {@link Saml2WebSsoAuthenticationRequestFilter} with the provided parameters * * @param relyingPartyRegistrationRepository a repository for relying party configurations + * @deprecated use the constructor that takes a {@link Saml2AuthenticationRequestFactory} */ + @Deprecated public Saml2WebSsoAuthenticationRequestFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) { + this(relyingPartyRegistrationRepository, + new org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationRequestFactory()); + } + + /** + * Construct a {@link Saml2WebSsoAuthenticationRequestFilter} with the provided parameters + * + * @param relyingPartyRegistrationRepository a repository for relying party configurations + * @since 5.4 + */ + public Saml2WebSsoAuthenticationRequestFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository, + Saml2AuthenticationRequestFactory authenticationRequestFactory) { Assert.notNull(relyingPartyRegistrationRepository, "relyingPartyRegistrationRepository cannot be null"); + Assert.notNull(authenticationRequestFactory, "authenticationRequestFactory cannot be null"); this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository; + this.authenticationRequestFactory = authenticationRequestFactory; } /** * Use the given {@link Saml2AuthenticationRequestFactory} for formulating the SAML 2.0 AuthnRequest * * @param authenticationRequestFactory the {@link Saml2AuthenticationRequestFactory} to use + * @deprecated use the constructor instead */ + @Deprecated public void setAuthenticationRequestFactory(Saml2AuthenticationRequestFactory authenticationRequestFactory) { Assert.notNull(authenticationRequestFactory, "authenticationRequestFactory cannot be null"); this.authenticationRequestFactory = authenticationRequestFactory; diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java index 8c51ce7b31..91d8e92f97 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java @@ -179,6 +179,29 @@ public class Saml2WebSsoAuthenticationRequestFilterTests { verify(this.factory).createPostAuthenticationRequest(any()); } + @Test + public void doFilterWhenCustomAuthenticationRequestFactoryThenUses() throws Exception { + RelyingPartyRegistration relyingParty = this.rpBuilder + .providerDetails(c -> c.binding(POST)) + .build(); + Saml2PostAuthenticationRequest authenticationRequest = mock(Saml2PostAuthenticationRequest.class); + when(authenticationRequest.getAuthenticationRequestUri()).thenReturn("uri"); + when(authenticationRequest.getRelayState()).thenReturn("relay"); + when(authenticationRequest.getSamlRequest()).thenReturn("saml"); + when(this.repository.findByRegistrationId("registration-id")).thenReturn(relyingParty); + when(this.factory.createPostAuthenticationRequest(any())) + .thenReturn(authenticationRequest); + + Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter + (this.repository, this.factory); + filter.doFilterInternal(this.request, this.response, this.filterChain); + assertThat(this.response.getContentAsString()) + .contains("
") + .contains("