diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolver.java index 2cd7a23064..b9fb1e451e 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolver.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolver.java @@ -37,6 +37,7 @@ import org.opensaml.saml.saml2.core.impl.IssuerBuilder; import org.opensaml.saml.saml2.core.impl.NameIDBuilder; import org.w3c.dom.Element; +import org.springframework.core.convert.converter.Converter; import org.springframework.security.saml2.Saml2Exception; import org.springframework.security.saml2.core.OpenSamlInitializationService; import org.springframework.security.saml2.core.Saml2ParameterNames; @@ -72,6 +73,8 @@ class OpenSamlAuthenticationRequestResolver { private final NameIDBuilder nameIdBuilder; + private Converter relayStateResolver = (request) -> UUID.randomUUID().toString(); + /** * Construct a {@link OpenSamlAuthenticationRequestResolver} using the provided * parameters @@ -94,6 +97,10 @@ class OpenSamlAuthenticationRequestResolver { Assert.notNull(this.nameIdBuilder, "nameIdBuilder must be configured in OpenSAML"); } + void setRelayStateResolver(Converter relayStateResolver) { + this.relayStateResolver = relayStateResolver; + } + T resolve(HttpServletRequest request) { return resolve(request, (registration, logoutRequest) -> { }); @@ -123,7 +130,7 @@ class OpenSamlAuthenticationRequestResolver { if (authnRequest.getID() == null) { authnRequest.setID("ARQ" + UUID.randomUUID().toString().substring(1)); } - String relayState = UUID.randomUUID().toString(); + String relayState = this.relayStateResolver.convert(request); Saml2MessageBinding binding = registration.getAssertingPartyDetails().getSingleSignOnServiceBinding(); if (binding == Saml2MessageBinding.POST) { if (registration.getAssertingPartyDetails().getWantAuthnRequestsSigned()) { diff --git a/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml4AuthenticationRequestResolver.java b/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml4AuthenticationRequestResolver.java index a117d8d845..6269373fab 100644 --- a/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml4AuthenticationRequestResolver.java +++ b/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml4AuthenticationRequestResolver.java @@ -24,6 +24,7 @@ import javax.servlet.http.HttpServletRequest; import org.opensaml.saml.saml2.core.AuthnRequest; +import org.springframework.core.convert.converter.Converter; import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver; @@ -78,6 +79,16 @@ public final class OpenSaml4AuthenticationRequestResolver implements Saml2Authen this.clock = clock; } + /** + * Use this {@link Converter} to compute the RelayState + * @param relayStateResolver the {@link Converter} to use + * @since 5.7 + */ + public void setRelayStateResolver(Converter relayStateResolver) { + Assert.notNull(relayStateResolver, "relayStateResolver cannot be null"); + this.authnRequestResolver.setRelayStateResolver(relayStateResolver); + } + public static final class AuthnRequestContext { private final HttpServletRequest request;