diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolver.java index 446cbf3bc6..4ddda990fd 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolver.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolver.java @@ -16,10 +16,6 @@ package org.springframework.security.saml2.provider.service.web; -import java.util.HashMap; -import java.util.Map; -import java.util.function.Function; - import jakarta.servlet.http.HttpServletRequest; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -27,13 +23,10 @@ import org.apache.commons.logging.LogFactory; import org.springframework.core.convert.converter.Converter; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; -import org.springframework.security.web.util.UrlUtils; +import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers.UriResolver; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; -import org.springframework.util.StringUtils; -import org.springframework.web.util.UriComponents; -import org.springframework.web.util.UriComponentsBuilder; /** * A {@link Converter} that resolves a {@link RelyingPartyRegistration} by extracting the @@ -48,8 +41,6 @@ public final class DefaultRelyingPartyRegistrationResolver private Log logger = LogFactory.getLog(getClass()); - private static final char PATH_DELIMITER = '/'; - private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository; private final RequestMatcher registrationRequestMatcher = new AntPathRequestMatcher("/**/{registrationId}"); @@ -87,61 +78,19 @@ public final class DefaultRelyingPartyRegistrationResolver } return null; } - RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationRepository + RelyingPartyRegistration registration = this.relyingPartyRegistrationRepository .findByRegistrationId(relyingPartyRegistrationId); - if (relyingPartyRegistration == null) { + if (registration == null) { return null; } - String applicationUri = getApplicationUri(request); - Function templateResolver = templateResolver(applicationUri, relyingPartyRegistration); - String relyingPartyEntityId = templateResolver.apply(relyingPartyRegistration.getEntityId()); - String assertionConsumerServiceLocation = templateResolver - .apply(relyingPartyRegistration.getAssertionConsumerServiceLocation()); - String singleLogoutServiceLocation = templateResolver - .apply(relyingPartyRegistration.getSingleLogoutServiceLocation()); - String singleLogoutServiceResponseLocation = templateResolver - .apply(relyingPartyRegistration.getSingleLogoutServiceResponseLocation()); - return relyingPartyRegistration.mutate().entityId(relyingPartyEntityId) - .assertionConsumerServiceLocation(assertionConsumerServiceLocation) - .singleLogoutServiceLocation(singleLogoutServiceLocation) - .singleLogoutServiceResponseLocation(singleLogoutServiceResponseLocation).build(); - } - - private Function templateResolver(String applicationUri, RelyingPartyRegistration relyingParty) { - return (template) -> resolveUrlTemplate(template, applicationUri, relyingParty); - } - - private static String resolveUrlTemplate(String template, String baseUrl, RelyingPartyRegistration relyingParty) { - if (template == null) { - return null; - } - String entityId = relyingParty.getAssertingPartyDetails().getEntityId(); - String registrationId = relyingParty.getRegistrationId(); - Map uriVariables = new HashMap<>(); - UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(baseUrl).replaceQuery(null).fragment(null) + UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration); + return registration.mutate().entityId(uriResolver.resolve(registration.getEntityId())) + .assertionConsumerServiceLocation( + uriResolver.resolve(registration.getAssertionConsumerServiceLocation())) + .singleLogoutServiceLocation(uriResolver.resolve(registration.getSingleLogoutServiceLocation())) + .singleLogoutServiceResponseLocation( + uriResolver.resolve(registration.getSingleLogoutServiceResponseLocation())) .build(); - String scheme = uriComponents.getScheme(); - uriVariables.put("baseScheme", (scheme != null) ? scheme : ""); - String host = uriComponents.getHost(); - uriVariables.put("baseHost", (host != null) ? host : ""); - // following logic is based on HierarchicalUriComponents#toUriString() - int port = uriComponents.getPort(); - uriVariables.put("basePort", (port == -1) ? "" : ":" + port); - String path = uriComponents.getPath(); - if (StringUtils.hasLength(path) && path.charAt(0) != PATH_DELIMITER) { - path = PATH_DELIMITER + path; - } - uriVariables.put("basePath", (path != null) ? path : ""); - uriVariables.put("baseUrl", uriComponents.toUriString()); - uriVariables.put("entityId", StringUtils.hasText(entityId) ? entityId : ""); - uriVariables.put("registrationId", StringUtils.hasText(registrationId) ? registrationId : ""); - return UriComponentsBuilder.fromUriString(template).buildAndExpand(uriVariables).toUriString(); - } - - private static String getApplicationUri(HttpServletRequest request) { - UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request)) - .replacePath(request.getContextPath()).replaceQuery(null).fragment(null).build(); - return uriComponents.toUriString(); } } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/RelyingPartyRegistrationPlaceholderResolvers.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/RelyingPartyRegistrationPlaceholderResolvers.java new file mode 100644 index 0000000000..b52857a2b6 --- /dev/null +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/RelyingPartyRegistrationPlaceholderResolvers.java @@ -0,0 +1,129 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web; + +import java.util.HashMap; +import java.util.Map; + +import jakarta.servlet.http.HttpServletRequest; + +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.web.util.UrlUtils; +import org.springframework.util.StringUtils; +import org.springframework.web.util.UriComponents; +import org.springframework.web.util.UriComponentsBuilder; + +/** + * A factory for creating placeholder resolvers for {@link RelyingPartyRegistration} + * templates. Supports {@code baseUrl}, {@code baseScheme}, {@code baseHost}, + * {@code basePort}, {@code basePath}, {@code registrationId}, + * {@code relyingPartyEntityId}, and {@code assertingPartyEntityId} + * + * @author Josh Cummings + * @since 6.1 + */ +public final class RelyingPartyRegistrationPlaceholderResolvers { + + private static final char PATH_DELIMITER = '/'; + + private RelyingPartyRegistrationPlaceholderResolvers() { + + } + + /** + * Create a resolver based on the given {@link HttpServletRequest}. Given the request, + * placeholders {@code baseUrl}, {@code baseScheme}, {@code baseHost}, + * {@code basePort}, and {@code basePath} are resolved. + * @param request the HTTP request + * @return a resolver that can resolve {@code baseUrl}, {@code baseScheme}, + * {@code baseHost}, {@code basePort}, and {@code basePath} placeholders + */ + public static UriResolver uriResolver(HttpServletRequest request) { + return new UriResolver(uriVariables(request)); + } + + /** + * Create a resolver based on the given {@link HttpServletRequest}. Given the request, + * placeholders {@code baseUrl}, {@code baseScheme}, {@code baseHost}, + * {@code basePort}, {@code basePath}, {@code registrationId}, + * {@code assertingPartyEntityId}, and {@code relyingPartyEntityId} are resolved. + * @param request the HTTP request + * @return a resolver that can resolve {@code baseUrl}, {@code baseScheme}, + * {@code baseHost}, {@code basePort}, {@code basePath}, {@code registrationId}, + * {@code relyingPartyEntityId}, and {@code assertingPartyEntityId} placeholders + */ + public static UriResolver uriResolver(HttpServletRequest request, RelyingPartyRegistration registration) { + String relyingPartyEntityId = registration.getEntityId(); + String assertingPartyEntityId = registration.getAssertingPartyDetails().getEntityId(); + String registrationId = registration.getRegistrationId(); + Map uriVariables = uriVariables(request); + uriVariables.put("relyingPartyEntityId", StringUtils.hasText(relyingPartyEntityId) ? relyingPartyEntityId : ""); + uriVariables.put("assertingPartyEntityId", + StringUtils.hasText(assertingPartyEntityId) ? assertingPartyEntityId : ""); + uriVariables.put("entityId", StringUtils.hasText(assertingPartyEntityId) ? assertingPartyEntityId : ""); + uriVariables.put("registrationId", StringUtils.hasText(registrationId) ? registrationId : ""); + return new UriResolver(uriVariables); + } + + private static Map uriVariables(HttpServletRequest request) { + String baseUrl = getApplicationUri(request); + Map uriVariables = new HashMap<>(); + UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(baseUrl).replaceQuery(null).fragment(null) + .build(); + String scheme = uriComponents.getScheme(); + uriVariables.put("baseScheme", (scheme != null) ? scheme : ""); + String host = uriComponents.getHost(); + uriVariables.put("baseHost", (host != null) ? host : ""); + // following logic is based on HierarchicalUriComponents#toUriString() + int port = uriComponents.getPort(); + uriVariables.put("basePort", (port == -1) ? "" : ":" + port); + String path = uriComponents.getPath(); + if (StringUtils.hasLength(path) && path.charAt(0) != PATH_DELIMITER) { + path = PATH_DELIMITER + path; + } + uriVariables.put("basePath", (path != null) ? path : ""); + uriVariables.put("baseUrl", uriComponents.toUriString()); + return uriVariables; + } + + private static String getApplicationUri(HttpServletRequest request) { + UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request)) + .replacePath(request.getContextPath()).replaceQuery(null).fragment(null).build(); + return uriComponents.toUriString(); + } + + /** + * A class for resolving {@link RelyingPartyRegistration} URIs + */ + public static final class UriResolver { + + private final Map uriVariables; + + private UriResolver(Map uriVariables) { + this.uriVariables = uriVariables; + } + + public String resolve(String uri) { + if (uri == null) { + return null; + } + return UriComponentsBuilder.fromUriString(uri).buildAndExpand(this.uriVariables).toUriString(); + } + + } + +} diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/RelyingPartyRegistrationPlaceholderResolversTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/RelyingPartyRegistrationPlaceholderResolversTests.java new file mode 100644 index 0000000000..3ceea878d8 --- /dev/null +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/RelyingPartyRegistrationPlaceholderResolversTests.java @@ -0,0 +1,56 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web; + +import org.junit.jupiter.api.Test; + +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; +import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers.UriResolver; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; + +/** + * Tests for {@link RelyingPartyRegistrationPlaceholderResolvers} + */ +public class RelyingPartyRegistrationPlaceholderResolversTests { + + @Test + void uriResolverGivenRequestCreatesResolver() { + MockHttpServletRequest request = new MockHttpServletRequest(); + UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request); + String resolved = uriResolver.resolve("{baseUrl}/extension"); + assertThat(resolved).isEqualTo("http://localhost/extension"); + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> uriResolver.resolve("{baseUrl}/extension/{registrationId}")); + } + + @Test + void uriResolverGivenRequestAndRegistrationCreatesResolver() { + MockHttpServletRequest request = new MockHttpServletRequest(); + RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration() + .entityId("http://sp.example.org").build(); + UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration); + String resolved = uriResolver.resolve("{baseUrl}/extension/{registrationId}"); + assertThat(resolved).isEqualTo("http://localhost/extension/simplesamlphp"); + resolved = uriResolver.resolve("{relyingPartyEntityId}/extension"); + assertThat(resolved).isEqualTo("http://sp.example.org/extension"); + } + +}