diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2ServletUtils.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2ServletUtils.java index 2088f55cf2..705acf6c0f 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2ServletUtils.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2ServletUtils.java @@ -16,15 +16,15 @@ package org.springframework.security.saml2.provider.service.servlet.filter; +import java.util.HashMap; +import java.util.Map; +import javax.servlet.http.HttpServletRequest; + import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.util.StringUtils; import org.springframework.web.util.UriComponents; import org.springframework.web.util.UriComponentsBuilder; -import javax.servlet.http.HttpServletRequest; -import java.util.HashMap; -import java.util.Map; - import static org.springframework.security.web.util.UrlUtils.buildFullRequestUrl; import static org.springframework.web.util.UriComponentsBuilder.fromHttpUrl; @@ -35,20 +35,13 @@ final class Saml2ServletUtils { private static final char PATH_DELIMITER = '/'; - static String getServiceProviderEntityId(RelyingPartyRegistration rp, HttpServletRequest request) { - return resolveUrlTemplate( - rp.getLocalEntityIdTemplate(), - getApplicationUri(request), - rp.getProviderDetails().getEntityId(), - rp.getRegistrationId() - ); - } - - static String resolveUrlTemplate(String template, String baseUrl, String entityId, String registrationId) { + static String resolveUrlTemplate(String template, String baseUrl, RelyingPartyRegistration relyingParty) { if (!StringUtils.hasText(template)) { return baseUrl; } + String entityId = relyingParty.getProviderDetails().getEntityId(); + String registrationId = relyingParty.getRegistrationId(); Map uriVariables = new HashMap<>(); UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(baseUrl) .replaceQuery(null) diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java index a332664be2..4fb2265aa1 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java @@ -16,6 +16,9 @@ package org.springframework.security.saml2.provider.service.servlet.filter; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.springframework.http.HttpMethod; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; @@ -30,9 +33,6 @@ import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - import static java.nio.charset.StandardCharsets.UTF_8; import static org.springframework.security.saml2.provider.service.authentication.Saml2ErrorCodes.RELYING_PARTY_REGISTRATION_NOT_FOUND; import static org.springframework.util.StringUtils.hasText; @@ -97,7 +97,8 @@ public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProce "Relying Party Registration not found with ID: " + registrationId); throw new Saml2AuthenticationException(saml2Error); } - String localSpEntityId = Saml2ServletUtils.getServiceProviderEntityId(rp, request); + String applicationUri = Saml2ServletUtils.getApplicationUri(request); + String localSpEntityId = Saml2ServletUtils.resolveUrlTemplate(rp.getLocalEntityIdTemplate(), applicationUri, rp); final Saml2AuthenticationToken authentication = new Saml2AuthenticationToken( responseXml, request.getRequestURL().toString(), diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java index 03c628d2d6..be7a13f47a 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java @@ -16,6 +16,13 @@ package org.springframework.security.saml2.provider.service.servlet.filter; +import java.io.IOException; +import java.util.function.Function; +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.springframework.http.MediaType; import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationRequestFactory; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext; @@ -34,12 +41,6 @@ import org.springframework.web.util.HtmlUtils; import org.springframework.web.util.UriComponentsBuilder; import org.springframework.web.util.UriUtils; -import javax.servlet.FilterChain; -import javax.servlet.ServletException; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.io.IOException; - import static java.lang.String.format; import static java.nio.charset.StandardCharsets.ISO_8859_1; import static org.springframework.util.StringUtils.hasText; @@ -137,22 +138,20 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter private Saml2AuthenticationRequestContext createRedirectAuthenticationRequestContext( RelyingPartyRegistration relyingParty, HttpServletRequest request) { - String localSpEntityId = Saml2ServletUtils.getServiceProviderEntityId(relyingParty, request); - return Saml2AuthenticationRequestContext - .builder() + String applicationUri = Saml2ServletUtils.getApplicationUri(request); + Function resolver = templateResolver(applicationUri, relyingParty); + String localSpEntityId = resolver.apply(relyingParty.getLocalEntityIdTemplate()); + String assertionConsumerServiceUrl = resolver.apply(relyingParty.getAssertionConsumerServiceUrlTemplate()); + return Saml2AuthenticationRequestContext.builder() .issuer(localSpEntityId) .relyingPartyRegistration(relyingParty) - .assertionConsumerServiceUrl( - Saml2ServletUtils.resolveUrlTemplate( - relyingParty.getAssertionConsumerServiceUrlTemplate(), - Saml2ServletUtils.getApplicationUri(request), - relyingParty.getProviderDetails().getEntityId(), - relyingParty.getRegistrationId() - ) - ) + .assertionConsumerServiceUrl(assertionConsumerServiceUrl) .relayState(request.getParameter("RelayState")) - .build() - ; + .build(); + } + + private Function templateResolver(String applicationUri, RelyingPartyRegistration relyingParty) { + return template -> Saml2ServletUtils.resolveUrlTemplate(template, applicationUri, relyingParty); } private String htmlEscape(String value) {