diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolver.java new file mode 100644 index 0000000000..3768233fdf --- /dev/null +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolver.java @@ -0,0 +1,135 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web; + +import java.util.HashMap; +import java.util.Map; +import java.util.function.Function; +import javax.servlet.http.HttpServletRequest; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.web.util.matcher.AntPathRequestMatcher; +import org.springframework.security.web.util.matcher.RequestMatcher; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; +import org.springframework.web.util.UriComponents; +import org.springframework.web.util.UriComponentsBuilder; + +import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration; +import static org.springframework.security.web.util.UrlUtils.buildFullRequestUrl; +import static org.springframework.web.util.UriComponentsBuilder.fromHttpUrl; + +/** + * A {@link Converter} that resolves a {@link RelyingPartyRegistration} by extracting the + * registration id from the request, querying a {@link RelyingPartyRegistrationRepository}, + * and resolving any template values. + * + * @since 5.4 + * @author Josh Cummings + */ +public final class DefaultRelyingPartyRegistrationResolver + implements Converter { + + private static final char PATH_DELIMITER = '/'; + + private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository; + private final Converter registrationIdResolver = new RegistrationIdResolver(); + + public DefaultRelyingPartyRegistrationResolver + (RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) { + + Assert.notNull(relyingPartyRegistrationRepository, "relyingPartyRegistrationRepository cannot be null"); + this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository; + } + + @Override + public RelyingPartyRegistration convert(HttpServletRequest request) { + String registrationId = this.registrationIdResolver.convert(request); + if (registrationId == null) { + return null; + } + RelyingPartyRegistration relyingPartyRegistration = + this.relyingPartyRegistrationRepository.findByRegistrationId(registrationId); + if (relyingPartyRegistration == null) { + return null; + } + + String applicationUri = getApplicationUri(request); + Function templateResolver = templateResolver(applicationUri, relyingPartyRegistration); + String relyingPartyEntityId = templateResolver.apply(relyingPartyRegistration.getEntityId()); + String assertionConsumerServiceLocation = templateResolver.apply( + relyingPartyRegistration.getAssertionConsumerServiceLocation()); + return withRelyingPartyRegistration(relyingPartyRegistration) + .entityId(relyingPartyEntityId) + .assertionConsumerServiceLocation(assertionConsumerServiceLocation) + .build(); + } + + private Function templateResolver(String applicationUri, RelyingPartyRegistration relyingParty) { + return template -> resolveUrlTemplate(template, applicationUri, relyingParty); + } + + private static String resolveUrlTemplate(String template, String baseUrl, RelyingPartyRegistration relyingParty) { + String entityId = relyingParty.getAssertingPartyDetails().getEntityId(); + String registrationId = relyingParty.getRegistrationId(); + Map uriVariables = new HashMap<>(); + UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(baseUrl) + .replaceQuery(null) + .fragment(null) + .build(); + String scheme = uriComponents.getScheme(); + uriVariables.put("baseScheme", scheme == null ? "" : scheme); + String host = uriComponents.getHost(); + uriVariables.put("baseHost", host == null ? "" : host); + // following logic is based on HierarchicalUriComponents#toUriString() + int port = uriComponents.getPort(); + uriVariables.put("basePort", port == -1 ? "" : ":" + port); + String path = uriComponents.getPath(); + if (StringUtils.hasLength(path) && path.charAt(0) != PATH_DELIMITER) { + path = PATH_DELIMITER + path; + } + uriVariables.put("basePath", path == null ? "" : path); + uriVariables.put("baseUrl", uriComponents.toUriString()); + uriVariables.put("entityId", StringUtils.hasText(entityId) ? entityId : ""); + uriVariables.put("registrationId", StringUtils.hasText(registrationId) ? registrationId : ""); + + return UriComponentsBuilder.fromUriString(template) + .buildAndExpand(uriVariables) + .toUriString(); + } + + private static String getApplicationUri(HttpServletRequest request) { + UriComponents uriComponents = fromHttpUrl(buildFullRequestUrl(request)) + .replacePath(request.getContextPath()) + .replaceQuery(null) + .fragment(null) + .build(); + return uriComponents.toUriString(); + } + + private static class RegistrationIdResolver implements Converter { + private final RequestMatcher requestMatcher = new AntPathRequestMatcher("/**/{registrationId}"); + + @Override + public String convert(HttpServletRequest request) { + RequestMatcher.MatchResult result = this.requestMatcher.matcher(request); + return result.getVariables().get("registrationId"); + } + } +} diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolverTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolverTests.java new file mode 100644 index 0000000000..693075f803 --- /dev/null +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolverTests.java @@ -0,0 +1,74 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web; + +import org.junit.Test; + +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.security.saml2.provider.service.registration.InMemoryRelyingPartyRegistrationRepository; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration; + +/** + * Tests for {@link DefaultRelyingPartyRegistrationResolver} + */ +public class DefaultRelyingPartyRegistrationResolverTests { + private final RelyingPartyRegistration registration = relyingPartyRegistration().build(); + private final RelyingPartyRegistrationRepository repository = + new InMemoryRelyingPartyRegistrationRepository(this.registration); + private final DefaultRelyingPartyRegistrationResolver resolver = + new DefaultRelyingPartyRegistrationResolver(this.repository); + + @Test + public void resolveWhenRequestContainsRegistrationIdThenResolves() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setPathInfo("/some/path/" + this.registration.getRegistrationId()); + RelyingPartyRegistration registration = this.resolver.convert(request); + assertThat(registration).isNotNull(); + assertThat(registration.getRegistrationId()) + .isEqualTo(this.registration.getRegistrationId()); + assertThat(registration.getEntityId()) + .isEqualTo("http://localhost/saml2/service-provider-metadata/" + this.registration.getRegistrationId()); + assertThat(registration.getAssertionConsumerServiceLocation()) + .isEqualTo("http://localhost/login/saml2/sso/" + this.registration.getRegistrationId()); + } + + @Test + public void resolveWhenRequestContainsInvalidRegistrationIdThenNull() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setPathInfo("/some/path/not-" + this.registration.getRegistrationId()); + RelyingPartyRegistration registration = this.resolver.convert(request); + assertThat(registration).isNull(); + } + + @Test + public void resolveWhenRequestIsMissingRegistrationIdThenNull() { + MockHttpServletRequest request = new MockHttpServletRequest(); + RelyingPartyRegistration registration = this.resolver.convert(request); + assertThat(registration).isNull(); + } + + @Test + public void constructorWhenNullRelyingPartyRegistrationThenIllegalArgument() { + assertThatCode(() -> new DefaultRelyingPartyRegistrationResolver(null)) + .isInstanceOf(IllegalArgumentException.class); + } +}