diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java index 01f133dfea..623a04cbba 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java @@ -16,13 +16,6 @@ package org.springframework.security.saml2.provider.service.servlet.filter; -import java.io.IOException; -import java.util.function.Function; -import javax.servlet.FilterChain; -import javax.servlet.ServletException; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - import org.springframework.http.MediaType; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory; @@ -31,6 +24,8 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2R import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; +import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver; +import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher.MatchResult; @@ -41,6 +36,12 @@ import org.springframework.web.util.HtmlUtils; import org.springframework.web.util.UriComponentsBuilder; import org.springframework.web.util.UriUtils; +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.io.IOException; + import static java.nio.charset.StandardCharsets.ISO_8859_1; /** @@ -70,6 +71,7 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository; private Saml2AuthenticationRequestFactory authenticationRequestFactory; + private Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver = new DefaultSaml2AuthenticationRequestContextResolver(); private RequestMatcher redirectMatcher = new AntPathRequestMatcher("/saml2/authenticate/{registrationId}"); @@ -121,6 +123,17 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter this.redirectMatcher = redirectMatcher; } + /** + * Use the given {@link Saml2AuthenticationRequestContextResolver} that creates a {@link Saml2AuthenticationRequestContext} + * + * @param authenticationRequestContextResolver the {@link Saml2AuthenticationRequestContextResolver} to use + * @since 5.4 + */ + public void setAuthenticationRequestContextResolver(Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver) { + Assert.notNull(authenticationRequestContextResolver, "authenticationRequestContextResolver cannot be null"); + this.authenticationRequestContextResolver = authenticationRequestContextResolver; + } + /** * {@inheritDoc} */ @@ -141,38 +154,14 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter response.sendError(HttpServletResponse.SC_UNAUTHORIZED); return; } - if (this.logger.isDebugEnabled()) { - this.logger.debug("Creating SAML 2.0 Authentication Request for Asserting Party [" + - relyingParty.getRegistrationId() + "]"); - } - Saml2AuthenticationRequestContext context = createRedirectAuthenticationRequestContext(request, relyingParty); + Saml2AuthenticationRequestContext context = authenticationRequestContextResolver.resolve(request, relyingParty); if (relyingParty.getProviderDetails().getBinding() == Saml2MessageBinding.REDIRECT) { sendRedirect(response, context); - } - else { + } else { sendPost(response, context); } } - private Saml2AuthenticationRequestContext createRedirectAuthenticationRequestContext( - HttpServletRequest request, RelyingPartyRegistration relyingParty) { - - String applicationUri = Saml2ServletUtils.getApplicationUri(request); - Function resolver = templateResolver(applicationUri, relyingParty); - String localSpEntityId = resolver.apply(relyingParty.getLocalEntityIdTemplate()); - String assertionConsumerServiceUrl = resolver.apply(relyingParty.getAssertionConsumerServiceUrlTemplate()); - return Saml2AuthenticationRequestContext.builder() - .issuer(localSpEntityId) - .relyingPartyRegistration(relyingParty) - .assertionConsumerServiceUrl(assertionConsumerServiceUrl) - .relayState(request.getParameter("RelayState")) - .build(); - } - - private Function templateResolver(String applicationUri, RelyingPartyRegistration relyingParty) { - return template -> Saml2ServletUtils.resolveUrlTemplate(template, applicationUri, relyingParty); - } - private void sendRedirect(HttpServletResponse response, Saml2AuthenticationRequestContext context) throws IOException { Saml2RedirectAuthenticationRequest authenticationRequest = diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolver.java new file mode 100644 index 0000000000..2209ff658f --- /dev/null +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolver.java @@ -0,0 +1,126 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; +import org.springframework.web.util.UriComponents; +import org.springframework.web.util.UriComponentsBuilder; + +import javax.servlet.http.HttpServletRequest; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Function; + +import static org.springframework.security.web.util.UrlUtils.buildFullRequestUrl; +import static org.springframework.web.util.UriComponentsBuilder.fromHttpUrl; + +/** + * The default implementation for {@link Saml2AuthenticationRequestContextResolver} + * which uses the current request and given relying party to formulate a {@link Saml2AuthenticationRequestContext} + * + * @author Shazin Sadakath + * @since 5.4 + */ +public final class DefaultSaml2AuthenticationRequestContextResolver implements Saml2AuthenticationRequestContextResolver { + + private final Log logger = LogFactory.getLog(getClass()); + + private static final char PATH_DELIMITER = '/'; + + /** + * {@inheritDoc} + */ + @Override + public Saml2AuthenticationRequestContext resolve(HttpServletRequest request, + RelyingPartyRegistration relyingParty) { + Assert.notNull(request, "request cannot be null"); + Assert.notNull(relyingParty, "relyingParty cannot be null"); + if (this.logger.isDebugEnabled()) { + this.logger.debug("Creating SAML 2.0 Authentication Request for Asserting Party [" + + relyingParty.getRegistrationId() + "]"); + } + return createRedirectAuthenticationRequestContext(request, relyingParty); + } + + private Saml2AuthenticationRequestContext createRedirectAuthenticationRequestContext( + HttpServletRequest request, RelyingPartyRegistration relyingParty) { + + String applicationUri = getApplicationUri(request); + Function resolver = templateResolver(applicationUri, relyingParty); + String localSpEntityId = resolver.apply(relyingParty.getLocalEntityIdTemplate()); + String assertionConsumerServiceUrl = resolver.apply(relyingParty.getAssertionConsumerServiceUrlTemplate()); + return Saml2AuthenticationRequestContext.builder() + .issuer(localSpEntityId) + .relyingPartyRegistration(relyingParty) + .assertionConsumerServiceUrl(assertionConsumerServiceUrl) + .relayState(request.getParameter("RelayState")) + .build(); + } + + private Function templateResolver(String applicationUri, RelyingPartyRegistration relyingParty) { + return template -> resolveUrlTemplate(template, applicationUri, relyingParty); + } + + private static String resolveUrlTemplate(String template, String baseUrl, RelyingPartyRegistration relyingParty) { + if (!StringUtils.hasText(template)) { + return baseUrl; + } + + String entityId = relyingParty.getProviderDetails().getEntityId(); + String registrationId = relyingParty.getRegistrationId(); + Map uriVariables = new HashMap<>(); + UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(baseUrl) + .replaceQuery(null) + .fragment(null) + .build(); + String scheme = uriComponents.getScheme(); + uriVariables.put("baseScheme", scheme == null ? "" : scheme); + String host = uriComponents.getHost(); + uriVariables.put("baseHost", host == null ? "" : host); + // following logic is based on HierarchicalUriComponents#toUriString() + int port = uriComponents.getPort(); + uriVariables.put("basePort", port == -1 ? "" : ":" + port); + String path = uriComponents.getPath(); + if (StringUtils.hasLength(path)) { + if (path.charAt(0) != PATH_DELIMITER) { + path = PATH_DELIMITER + path; + } + } + uriVariables.put("basePath", path == null ? "" : path); + uriVariables.put("baseUrl", uriComponents.toUriString()); + uriVariables.put("entityId", StringUtils.hasText(entityId) ? entityId : ""); + uriVariables.put("registrationId", StringUtils.hasText(registrationId) ? registrationId : ""); + + return UriComponentsBuilder.fromUriString(template) + .buildAndExpand(uriVariables) + .toUriString(); + } + + private static String getApplicationUri(HttpServletRequest request) { + UriComponents uriComponents = fromHttpUrl(buildFullRequestUrl(request)) + .replacePath(request.getContextPath()) + .replaceQuery(null) + .fragment(null) + .build(); + return uriComponents.toUriString(); + } +} diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationRequestContextResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationRequestContextResolver.java new file mode 100644 index 0000000000..1c86ec239e --- /dev/null +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationRequestContextResolver.java @@ -0,0 +1,43 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web; + +import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; + +import javax.servlet.http.HttpServletRequest; + +/** + * This {@code Saml2AuthenticationRequestContextResolver} formulates a + * SAML 2.0 AuthnRequest (line 1968) + * + * @author Shazin Sadakath + * @since 5.4 + */ +public interface Saml2AuthenticationRequestContextResolver { + + /** + * This {@code resolve} method is defined to create a {@link Saml2AuthenticationRequestContext} + * + * + * @param request the current request + * @param relyingParty the relying party responsible for saml2 sso authentication + * @return the created {@link Saml2AuthenticationRequestContext} for request/relying party combination + */ + Saml2AuthenticationRequestContext resolve(HttpServletRequest request, + RelyingPartyRegistration relyingParty); +} diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/TestSaml2SigningCredentials.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/TestSaml2SigningCredentials.java index 3aa718227e..cec591cbed 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/TestSaml2SigningCredentials.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/TestSaml2SigningCredentials.java @@ -31,9 +31,9 @@ import static java.nio.charset.StandardCharsets.UTF_8; import static org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.DECRYPTION; import static org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.SIGNING; -final class TestSaml2SigningCredentials { +public final class TestSaml2SigningCredentials { - static Saml2X509Credential signingCredential() { + public static Saml2X509Credential signingCredential() { return new Saml2X509Credential(idpPrivateKey(), idpCertificate(), SIGNING, DECRYPTION); } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolverTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolverTests.java new file mode 100644 index 0000000000..4044c582d8 --- /dev/null +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolverTests.java @@ -0,0 +1,66 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; + +import static org.springframework.security.saml2.provider.service.servlet.filter.TestSaml2SigningCredentials.signingCredential; +import static org.assertj.core.api.Assertions.*; + +public class DefaultSaml2AuthenticationRequestContextResolverTests { + + private static final String IDP_SSO_URL = "https://sso-url.example.com/IDP/SSO"; + private static final String TEMPLATE = "template"; + private static final String REGISTRATION_ID = "registration-id"; + private static final String IDP_ENTITY_ID = "idp-entity-id"; + + private MockHttpServletRequest request; + private RelyingPartyRegistration.Builder rpBuilder; + private Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver = new DefaultSaml2AuthenticationRequestContextResolver(); + + @Before + public void setup() { + request = new MockHttpServletRequest(); + rpBuilder = RelyingPartyRegistration + .withRegistrationId(REGISTRATION_ID) + .providerDetails(c -> c.entityId(IDP_ENTITY_ID)) + .providerDetails(c -> c.webSsoUrl(IDP_SSO_URL)) + .assertionConsumerServiceUrlTemplate(TEMPLATE) + .credentials(c -> c.add(signingCredential())); + } + + @Test + public void resoleWhenRequestAndRelyingPartyNotNullThenCreateSaml2AuthenticationRequestContext() { + Saml2AuthenticationRequestContext authenticationRequestContext = authenticationRequestContextResolver.resolve(request, rpBuilder.build()); + + assertThat(authenticationRequestContext).isNotNull(); + assertThat(authenticationRequestContext.getAssertionConsumerServiceUrl()).isEqualTo(TEMPLATE); + assertThat(authenticationRequestContext.getRelyingPartyRegistration().getRegistrationId()).isEqualTo(REGISTRATION_ID); + assertThat(authenticationRequestContext.getRelyingPartyRegistration().getProviderDetails().getEntityId()).isEqualTo(IDP_ENTITY_ID); + assertThat(authenticationRequestContext.getRelyingPartyRegistration().getProviderDetails().getWebSsoUrl()).isEqualTo(IDP_SSO_URL); + assertThat(authenticationRequestContext.getRelyingPartyRegistration().getCredentials()).isNotEmpty(); + } + + @Test(expected = IllegalArgumentException.class) + public void resolveWhenRequestAndRelyingPartyNullThenException() { + authenticationRequestContextResolver.resolve(null, null); + } +}