diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java index 3b8220e80e..9860e8040b 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java @@ -47,6 +47,7 @@ import org.springframework.security.saml2.provider.service.servlet.filter.Saml2W import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter; import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver; import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver; +import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter; import org.springframework.security.web.authentication.AuthenticationConverter; @@ -264,7 +265,8 @@ public final class Saml2LoginConfigurer> private AuthenticationConverter getAuthenticationConverter(B http) { if (this.authenticationConverter == null) { return new Saml2AuthenticationTokenConverter( - new DefaultRelyingPartyRegistrationResolver(this.relyingPartyRegistrationRepository)); + (RelyingPartyRegistrationResolver) new DefaultRelyingPartyRegistrationResolver( + this.relyingPartyRegistrationRepository)); } return this.authenticationConverter; } @@ -390,8 +392,9 @@ public final class Saml2LoginConfigurer> Saml2AuthenticationRequestContextResolver resolver = getBeanOrNull(http, Saml2AuthenticationRequestContextResolver.class); if (resolver == null) { - return new DefaultSaml2AuthenticationRequestContextResolver(new DefaultRelyingPartyRegistrationResolver( - Saml2LoginConfigurer.this.relyingPartyRegistrationRepository)); + RelyingPartyRegistrationResolver relyingPartyRegistrationResolver = new DefaultRelyingPartyRegistrationResolver( + Saml2LoginConfigurer.this.relyingPartyRegistrationRepository); + return new DefaultSaml2AuthenticationRequestContextResolver(relyingPartyRegistrationResolver); } return resolver; } diff --git a/docs/manual/src/docs/asciidoc/_includes/servlet/saml2/saml2-login.adoc b/docs/manual/src/docs/asciidoc/_includes/servlet/saml2/saml2-login.adoc index 303351fb79..33e7c5ad35 100644 --- a/docs/manual/src/docs/asciidoc/_includes/servlet/saml2/saml2-login.adoc +++ b/docs/manual/src/docs/asciidoc/_includes/servlet/saml2/saml2-login.adoc @@ -727,7 +727,7 @@ There are a number of reasons you may want to customize. Among them: * You may know that you will never be a multi-tenant application and so want to have a simpler URL scheme * You may identify tenants in a way other than by the URI path -To customize the way that a `RelyingPartyRegistration` is resolved, you can configure a custom `Converter`. +To customize the way that a `RelyingPartyRegistration` is resolved, you can configure a custom `RelyingPartyRegistrationResolver`. The default looks up the registration id from the URI's last path element and looks it up in your `RelyingPartyRegistrationRepository`. You can provide a simpler resolver that, for example, always returns the same relying party: @@ -736,12 +736,17 @@ You can provide a simpler resolver that, for example, always returns the same re .Java [source,java,role="primary"] ---- -public class SingleRelyingPartyRegistrationResolver - implements Converter { +public class SingleRelyingPartyRegistrationResolver implements RelyingPartyRegistrationResolver { + + private final RelyingPartyRegistrationResolver delegate; + + public SingleRelyingPartyRegistrationResolver(RelyingPartyRegistrationRepository registrations) { + this.delegate = new DefaultRelyingPartyRegistrationResolver(registrations); + } @Override - public RelyingPartyRegistration convert(HttpServletRequest request) { - return this.relyingParty; + public RelyingPartyRegistration resolve(HttpServletRequest request, String registrationId) { + return this.delegate.resolve(request, "single"); } } ---- @@ -749,9 +754,9 @@ public class SingleRelyingPartyRegistrationResolver .Kotlin [source,kotlin,role="secondary"] ---- -class SingleRelyingPartyRegistrationResolver : Converter { - override fun convert(request: HttpServletRequest?): RelyingPartyRegistration? { - return this.relyingParty +class SingleRelyingPartyRegistrationResolver(delegate: RelyingPartyRegistrationResolver) : RelyingPartyRegistrationResolver { + override fun resolve(request: HttpServletRequest?, registrationId: String?): RelyingPartyRegistration? { + return this.delegate.resolve(request, "single") } } ---- @@ -1544,7 +1549,7 @@ You can publish a metadata endpoint by adding the `Saml2MetadataFilter` to the f .Java [source,java,role="primary"] ---- -Converter relyingPartyRegistrationResolver = +DefaultRelyingPartyRegistrationResolver relyingPartyRegistrationResolver = new DefaultRelyingPartyRegistrationResolver(this.relyingPartyRegistrationRepository); Saml2MetadataFilter filter = new Saml2MetadataFilter( relyingPartyRegistrationResolver, @@ -1594,8 +1599,6 @@ filter.setRequestMatcher(AntPathRequestMatcher("/saml2/metadata/{registrationId} ---- ==== -ensuring that the `registrationId` hint is at the end of the path. - Or, if you have registered a custom relying party registration resolver in the constructor, then you can specify a path without a `registrationId` hint, like so: ==== diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java index c59ec4deeb..b5fc9e01b3 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java @@ -29,6 +29,7 @@ import org.springframework.security.saml2.provider.service.registration.RelyingP import org.springframework.security.saml2.provider.service.servlet.HttpSessionSaml2AuthenticationRequestRepository; import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository; import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver; +import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter; import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter; import org.springframework.security.web.authentication.AuthenticationConverter; @@ -67,7 +68,9 @@ public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProce public Saml2WebSsoAuthenticationFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository, String filterProcessesUrl) { this(new Saml2AuthenticationTokenConverter( - new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository)), filterProcessesUrl); + (RelyingPartyRegistrationResolver) new DefaultRelyingPartyRegistrationResolver( + relyingPartyRegistrationRepository)), + filterProcessesUrl); } /** diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java index 39819a513a..b1ceadd08f 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java @@ -39,6 +39,7 @@ import org.springframework.security.saml2.provider.service.servlet.HttpSessionSa import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository; import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver; import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver; +import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; @@ -96,7 +97,9 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter public Saml2WebSsoAuthenticationRequestFilter( RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) { this(new DefaultSaml2AuthenticationRequestContextResolver( - new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository)), requestFactory()); + (RelyingPartyRegistrationResolver) new DefaultRelyingPartyRegistrationResolver( + relyingPartyRegistrationRepository)), + requestFactory()); } private static Saml2AuthenticationRequestFactory requestFactory() { diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolver.java index 10b667847c..ce8ae7e448 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolver.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolver.java @@ -22,6 +22,9 @@ import java.util.function.Function; import javax.servlet.http.HttpServletRequest; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + import org.springframework.core.convert.converter.Converter; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; @@ -42,13 +45,15 @@ import org.springframework.web.util.UriComponentsBuilder; * @since 5.4 */ public final class DefaultRelyingPartyRegistrationResolver - implements Converter { + implements RelyingPartyRegistrationResolver, Converter { + + private Log logger = LogFactory.getLog(getClass()); private static final char PATH_DELIMITER = '/'; private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository; - private final Converter registrationIdResolver = new RegistrationIdResolver(); + private final RequestMatcher registrationRequestMatcher = new AntPathRequestMatcher("/**/{registrationId}"); public DefaultRelyingPartyRegistrationResolver( RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) { @@ -56,14 +61,35 @@ public final class DefaultRelyingPartyRegistrationResolver this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository; } + /** + * {@inheritDoc} + */ @Override public RelyingPartyRegistration convert(HttpServletRequest request) { - String registrationId = this.registrationIdResolver.convert(request); - if (registrationId == null) { + return resolve(request, null); + } + + /** + * {@inheritDoc} + */ + @Override + public RelyingPartyRegistration resolve(HttpServletRequest request, String relyingPartyRegistrationId) { + if (relyingPartyRegistrationId == null) { + if (this.logger.isTraceEnabled()) { + this.logger.trace("Attempting to resolve from " + this.registrationRequestMatcher + + " since registrationId is null"); + } + relyingPartyRegistrationId = this.registrationRequestMatcher.matcher(request).getVariables() + .get("registrationId"); + } + if (relyingPartyRegistrationId == null) { + if (this.logger.isTraceEnabled()) { + this.logger.trace("Returning null registration since registrationId is null"); + } return null; } RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationRepository - .findByRegistrationId(registrationId); + .findByRegistrationId(relyingPartyRegistrationId); if (relyingPartyRegistration == null) { return null; } @@ -111,16 +137,4 @@ public final class DefaultRelyingPartyRegistrationResolver return uriComponents.toUriString(); } - private static class RegistrationIdResolver implements Converter { - - private final RequestMatcher requestMatcher = new AntPathRequestMatcher("/**/{registrationId}"); - - @Override - public String convert(HttpServletRequest request) { - RequestMatcher.MatchResult result = this.requestMatcher.matcher(request); - return result.getVariables().get("registrationId"); - } - - } - } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolver.java index a6cdb3ed91..d95472e8e3 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolver.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolver.java @@ -42,11 +42,24 @@ public final class DefaultSaml2AuthenticationRequestContextResolver private final Converter relyingPartyRegistrationResolver; + /** + * Construct a {@link DefaultSaml2AuthenticationRequestContextResolver} + * @param relyingPartyRegistrationResolver + * @deprecated Use + * {@link DefaultSaml2AuthenticationRequestContextResolver#DefaultSaml2AuthenticationRequestContextResolver(RelyingPartyRegistrationResolver)} + * instead + */ + @Deprecated public DefaultSaml2AuthenticationRequestContextResolver( Converter relyingPartyRegistrationResolver) { this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver; } + public DefaultSaml2AuthenticationRequestContextResolver( + RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) { + this.relyingPartyRegistrationResolver = (request) -> relyingPartyRegistrationResolver.resolve(request, null); + } + @Override public Saml2AuthenticationRequestContext resolve(HttpServletRequest request) { Assert.notNull(request, "request cannot be null"); diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/RelyingPartyRegistrationResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/RelyingPartyRegistrationResolver.java new file mode 100644 index 0000000000..d9e5e0eb14 --- /dev/null +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/RelyingPartyRegistrationResolver.java @@ -0,0 +1,40 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web; + +import javax.servlet.http.HttpServletRequest; + +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; + +/** + * A contract for resolving a {@link RelyingPartyRegistration} from the HTTP request + * + * @author Josh Cummings + * @since 5.6 + */ +public interface RelyingPartyRegistrationResolver { + + /** + * Resolve a {@link RelyingPartyRegistration} from the HTTP request, using the + * {@code relyingPartyRegistrationId}, if it is provided + * @param request the HTTP request + * @param relyingPartyRegistrationId the {@link RelyingPartyRegistration} identifier + * @return the resolved {@link RelyingPartyRegistration} + */ + RelyingPartyRegistration resolve(HttpServletRequest request, String relyingPartyRegistrationId); + +} diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java index 91f8f3e95c..d0dfa986e9 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java @@ -61,7 +61,11 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo * resolving {@link RelyingPartyRegistration}s * @param relyingPartyRegistrationResolver the strategy for resolving * {@link RelyingPartyRegistration}s + * @deprecated Use + * {@link Saml2AuthenticationTokenConverter#Saml2AuthenticationTokenConverter(RelyingPartyRegistrationResolver)} + * instead */ + @Deprecated public Saml2AuthenticationTokenConverter( Converter relyingPartyRegistrationResolver) { Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null"); @@ -69,6 +73,16 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo this.loader = new HttpSessionSaml2AuthenticationRequestRepository()::loadAuthenticationRequest; } + public Saml2AuthenticationTokenConverter(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) { + this(adaptToConverter(relyingPartyRegistrationResolver)); + } + + private static Converter adaptToConverter( + RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) { + Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null"); + return (request) -> relyingPartyRegistrationResolver.resolve(request, null); + } + @Override public Saml2AuthenticationToken convert(HttpServletRequest request) { RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationResolver.convert(request); diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilter.java index 57ec493bc8..f01ce2d430 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilter.java @@ -46,7 +46,7 @@ public final class Saml2MetadataFilter extends OncePerRequestFilter { public static final String DEFAULT_METADATA_FILE_NAME = "saml-{registrationId}-metadata.xml"; - private final Converter relyingPartyRegistrationConverter; + private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver; private final Saml2MetadataResolver saml2MetadataResolver; @@ -55,11 +55,26 @@ public final class Saml2MetadataFilter extends OncePerRequestFilter { private RequestMatcher requestMatcher = new AntPathRequestMatcher( "/saml2/service-provider-metadata/{registrationId}"); - public Saml2MetadataFilter( - Converter relyingPartyRegistrationConverter, + /** + * Construct a {@link Saml2MetadataFilter} + * @param relyingPartyRegistrationResolver + * @param saml2MetadataResolver + * @deprecated Use + * {@link Saml2MetadataFilter#Saml2MetadataFilter(RelyingPartyRegistrationResolver)} + * instead + */ + @Deprecated + public Saml2MetadataFilter(Converter relyingPartyRegistrationResolver, Saml2MetadataResolver saml2MetadataResolver) { + this.relyingPartyRegistrationResolver = (request, id) -> relyingPartyRegistrationResolver.convert(request); + this.saml2MetadataResolver = saml2MetadataResolver; + } - this.relyingPartyRegistrationConverter = relyingPartyRegistrationConverter; + public Saml2MetadataFilter(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver, + Saml2MetadataResolver saml2MetadataResolver) { + Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null"); + Assert.notNull(saml2MetadataResolver, "saml2MetadataResolver cannot be null"); + this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver; this.saml2MetadataResolver = saml2MetadataResolver; } @@ -71,14 +86,15 @@ public final class Saml2MetadataFilter extends OncePerRequestFilter { chain.doFilter(request, response); return; } - RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationConverter.convert(request); + String registrationId = matcher.getVariables().get("registrationId"); + RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationResolver.resolve(request, + registrationId); if (relyingPartyRegistration == null) { response.setStatus(HttpServletResponse.SC_UNAUTHORIZED); return; } String metadata = this.saml2MetadataResolver.resolve(relyingPartyRegistration); - String registrationId = relyingPartyRegistration.getRegistrationId(); - writeMetadataToResponse(response, registrationId, metadata); + writeMetadataToResponse(response, relyingPartyRegistration.getRegistrationId(), metadata); } private void writeMetadataToResponse(HttpServletResponse response, String registrationId, String metadata) diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilterTests.java index 11b07fd2fc..914f370154 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilterTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilterTests.java @@ -22,15 +22,25 @@ import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException; import org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationTokens; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository; +import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver; +import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter; import org.springframework.security.web.authentication.AuthenticationConverter; +import org.springframework.security.web.util.matcher.AntPathRequestMatcher; +import org.springframework.security.web.util.matcher.RequestMatcher; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; @@ -49,6 +59,8 @@ public class Saml2WebSsoAuthenticationFilterTests { private HttpServletResponse response = new MockHttpServletResponse(); + private AuthenticationManager authenticationManager = mock(AuthenticationManager.class); + @BeforeEach public void setup() { this.filter = new Saml2WebSsoAuthenticationFilter(this.repository); @@ -132,4 +144,26 @@ public class Saml2WebSsoAuthenticationFilterTests { verifyNoInteractions(authenticationConverter); } + @Test + public void doFilterWhenPathStartsWithRegistrationIdThenAuthenticates() throws Exception { + RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build(); + Authentication authentication = new TestingAuthenticationToken("user", "password"); + given(this.repository.findByRegistrationId("registration-id")).willReturn(registration); + given(this.authenticationManager.authenticate(authentication)).willReturn(authentication); + String loginProcessingUrl = "/{registrationId}/login/saml2/sso"; + RequestMatcher matcher = new AntPathRequestMatcher(loginProcessingUrl); + DefaultRelyingPartyRegistrationResolver delegate = new DefaultRelyingPartyRegistrationResolver(this.repository); + RelyingPartyRegistrationResolver resolver = (request, id) -> { + String registrationId = matcher.matcher(request).getVariables().get("registrationId"); + return delegate.resolve(request, registrationId); + }; + Saml2AuthenticationTokenConverter authenticationConverter = new Saml2AuthenticationTokenConverter(resolver); + this.filter = new Saml2WebSsoAuthenticationFilter(authenticationConverter, loginProcessingUrl); + this.filter.setAuthenticationManager(this.authenticationManager); + this.request.setPathInfo("/registration-id/login/saml2/sso"); + this.request.setParameter("SAMLResponse", "response"); + this.filter.doFilter(this.request, this.response, new MockFilterChain()); + verify(this.repository).findByRegistrationId("registration-id"); + } + } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java index 0eda04f267..5afcc554b1 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java @@ -37,8 +37,14 @@ import org.springframework.security.saml2.provider.service.authentication.TestSa import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; +import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository; +import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver; +import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver; +import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver; +import org.springframework.security.web.util.matcher.AntPathRequestMatcher; +import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.web.util.HtmlUtils; import org.springframework.web.util.UriUtils; @@ -256,4 +262,29 @@ public class Saml2WebSsoAuthenticationRequestFilterTests { any(Saml2PostAuthenticationRequest.class), eq(this.request), eq(this.response)); } + @Test + public void doFilterWhenPathStartsWithRegistrationIdThenPosts() throws Exception { + RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full() + .assertingPartyDetails((party) -> party.singleSignOnServiceBinding(Saml2MessageBinding.POST)).build(); + RequestMatcher matcher = new AntPathRequestMatcher("/{registrationId}/saml2/authenticate"); + DefaultRelyingPartyRegistrationResolver delegate = new DefaultRelyingPartyRegistrationResolver(this.repository); + RelyingPartyRegistrationResolver resolver = (request, id) -> { + String registrationId = matcher.matcher(request).getVariables().get("registrationId"); + return delegate.resolve(request, registrationId); + }; + Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver = new DefaultSaml2AuthenticationRequestContextResolver( + resolver); + Saml2PostAuthenticationRequest authenticationRequest = mock(Saml2PostAuthenticationRequest.class); + given(authenticationRequest.getAuthenticationRequestUri()).willReturn("uri"); + given(authenticationRequest.getRelayState()).willReturn("relay"); + given(authenticationRequest.getSamlRequest()).willReturn("saml"); + given(this.repository.findByRegistrationId("registration-id")).willReturn(registration); + given(this.factory.createPostAuthenticationRequest(any())).willReturn(authenticationRequest); + this.filter = new Saml2WebSsoAuthenticationRequestFilter(authenticationRequestContextResolver, this.factory); + this.filter.setRedirectMatcher(matcher); + this.request.setPathInfo("/registration-id/saml2/authenticate"); + this.filter.doFilter(this.request, this.response, new MockFilterChain()); + verify(this.repository).findByRegistrationId("registration-id"); + } + } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolverTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolverTests.java index 273fe0484e..b73d7d65ab 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolverTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolverTests.java @@ -49,8 +49,11 @@ public class DefaultSaml2AuthenticationRequestContextResolverTests { private RelyingPartyRegistration.Builder relyingPartyBuilder; + private RelyingPartyRegistrationResolver relyingPartyRegistrationResolver = new DefaultRelyingPartyRegistrationResolver( + (id) -> this.relyingPartyBuilder.build()); + private Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver = new DefaultSaml2AuthenticationRequestContextResolver( - new DefaultRelyingPartyRegistrationResolver((id) -> this.relyingPartyBuilder.build())); + this.relyingPartyRegistrationResolver); @BeforeEach public void setup() { diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java index e922dcc699..9fe6aef59b 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java @@ -176,7 +176,8 @@ public class Saml2AuthenticationTokenConverterTests { @Test public void constructorWhenResolverIsNullThenIllegalArgument() { - assertThatIllegalArgumentException().isThrownBy(() -> new Saml2AuthenticationTokenConverter(null)); + assertThatIllegalArgumentException() + .isThrownBy(() -> new Saml2AuthenticationTokenConverter((RelyingPartyRegistrationResolver) null)); } @Test diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilterTests.java index ced8ad5a87..0f40eebdf3 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilterTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilterTests.java @@ -25,6 +25,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.springframework.http.HttpHeaders; +import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.saml2.core.TestSaml2X509Credentials; @@ -37,6 +38,7 @@ import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -63,8 +65,9 @@ public class Saml2MetadataFilterTests { public void setup() { this.repository = mock(RelyingPartyRegistrationRepository.class); this.resolver = mock(Saml2MetadataResolver.class); - this.filter = new Saml2MetadataFilter(new DefaultRelyingPartyRegistrationResolver(this.repository), - this.resolver); + RelyingPartyRegistrationResolver relyingPartyRegistrationResolver = new DefaultRelyingPartyRegistrationResolver( + this.repository); + this.filter = new Saml2MetadataFilter(relyingPartyRegistrationResolver, this.resolver); this.request = new MockHttpServletRequest(); this.response = new MockHttpServletResponse(); this.chain = mock(FilterChain.class); @@ -136,6 +139,20 @@ public class Saml2MetadataFilterTests { .isEqualTo("attachment; filename=\"%s\"; filename*=UTF-8''%s", fileName, encodedFileName); } + @Test + public void doFilterWhenPathStartsWithRegistrationIdThenServesMetadata() throws Exception { + RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build(); + given(this.repository.findByRegistrationId("registration-id")).willReturn(registration); + given(this.resolver.resolve(any())).willReturn("metadata"); + RelyingPartyRegistrationResolver resolver = new DefaultRelyingPartyRegistrationResolver( + (id) -> this.repository.findByRegistrationId("registration-id")); + this.filter = new Saml2MetadataFilter(resolver, this.resolver); + this.filter.setRequestMatcher(new AntPathRequestMatcher("/metadata")); + this.request.setPathInfo("/metadata"); + this.filter.doFilter(this.request, this.response, new MockFilterChain()); + verify(this.repository).findByRegistrationId("registration-id"); + } + @Test public void setRequestMatcherWhenNullThenIllegalArgument() { assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setRequestMatcher(null));