diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java index 0915607b2f..3b93c84353 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java @@ -21,11 +21,14 @@ import java.security.PrivateKey; import java.security.cert.X509Certificate; import java.time.Clock; import java.time.Instant; -import java.util.Collection; +import java.util.ArrayList; +import java.util.Collections; import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; import java.util.UUID; +import net.shibboleth.utilities.java.support.resolver.CriteriaSet; import net.shibboleth.utilities.java.support.xml.SerializeSupport; import org.joda.time.DateTime; import org.opensaml.core.config.ConfigurationService; @@ -37,15 +40,18 @@ import org.opensaml.saml.saml2.core.Issuer; import org.opensaml.saml.saml2.core.impl.AuthnRequestBuilder; import org.opensaml.saml.saml2.core.impl.AuthnRequestMarshaller; import org.opensaml.saml.saml2.core.impl.IssuerBuilder; +import org.opensaml.saml.security.impl.SAMLMetadataSignatureSigningParametersResolver; import org.opensaml.security.SecurityException; import org.opensaml.security.credential.BasicCredential; import org.opensaml.security.credential.Credential; import org.opensaml.security.credential.CredentialSupport; import org.opensaml.security.credential.UsageType; import org.opensaml.xmlsec.SignatureSigningParameters; +import org.opensaml.xmlsec.SignatureSigningParametersResolver; +import org.opensaml.xmlsec.criterion.SignatureSigningConfigurationCriterion; import org.opensaml.xmlsec.crypto.XMLSigningUtil; +import org.opensaml.xmlsec.impl.BasicSignatureSigningConfiguration; import org.opensaml.xmlsec.signature.support.SignatureConstants; -import org.opensaml.xmlsec.signature.support.SignatureException; import org.opensaml.xmlsec.signature.support.SignatureSupport; import org.w3c.dom.Element; @@ -58,6 +64,7 @@ import org.springframework.security.saml2.provider.service.registration.RelyingP import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; import org.springframework.util.Assert; import org.springframework.util.StringUtils; +import org.springframework.web.util.UriComponentsBuilder; import org.springframework.web.util.UriUtils; /** @@ -105,9 +112,17 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication request.getAssertionConsumerServiceUrl(), this.protocolBindingResolver.convert(null)); for (org.springframework.security.saml2.credentials.Saml2X509Credential credential : request.getCredentials()) { if (credential.isSigningCredential()) { - Credential cred = getSigningCredential(credential.getCertificate(), credential.getPrivateKey(), - request.getIssuer()); - return serialize(sign(authnRequest, cred)); + X509Certificate certificate = credential.getCertificate(); + PrivateKey privateKey = credential.getPrivateKey(); + BasicCredential cred = CredentialSupport.getSimpleCredential(certificate, privateKey); + cred.setEntityId(request.getIssuer()); + cred.setUsageType(UsageType.SIGNING); + SignatureSigningParameters parameters = new SignatureSigningParameters(); + parameters.setSigningCredential(cred); + parameters.setSignatureAlgorithm(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256); + parameters.setSignatureReferenceDigestMethod(SignatureConstants.ALGO_ID_DIGEST_SHA256); + parameters.setSignatureCanonicalizationAlgorithm(SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS); + return serialize(sign(authnRequest, parameters)); } } throw new IllegalArgumentException("No signing credential provided"); @@ -132,16 +147,13 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication String deflatedAndEncoded = Saml2Utils.samlEncode(Saml2Utils.samlDeflate(xml)); result.samlRequest(deflatedAndEncoded).relayState(context.getRelayState()); if (context.getRelyingPartyRegistration().getAssertingPartyDetails().getWantAuthnRequestsSigned()) { - Collection signingCredentials = context.getRelyingPartyRegistration() - .getSigningX509Credentials(); - for (Saml2X509Credential credential : signingCredentials) { - Credential cred = getSigningCredential(credential.getCertificate(), credential.getPrivateKey(), ""); - Map signedParams = signQueryParameters(cred, deflatedAndEncoded, - context.getRelayState()); - return result.samlRequest(signedParams.get("SAMLRequest")).relayState(signedParams.get("RelayState")) - .sigAlg(signedParams.get("SigAlg")).signature(signedParams.get("Signature")).build(); + Map parameters = new LinkedHashMap<>(); + parameters.put("SAMLRequest", deflatedAndEncoded); + if (StringUtils.hasText(context.getRelayState())) { + parameters.put("RelayState", context.getRelayState()); } - throw new Saml2Exception("No signing credential provided"); + sign(parameters, context.getRelyingPartyRegistration()); + return result.sigAlg(parameters.get("SigAlg")).signature(parameters.get("Signature")).build(); } return result.build(); } @@ -211,59 +223,39 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication } private AuthnRequest sign(AuthnRequest authnRequest, RelyingPartyRegistration relyingPartyRegistration) { - for (Saml2X509Credential credential : relyingPartyRegistration.getSigningX509Credentials()) { - Credential cred = getSigningCredential(credential.getCertificate(), credential.getPrivateKey(), - relyingPartyRegistration.getEntityId()); - return sign(authnRequest, cred); - } - throw new IllegalArgumentException("No signing credential provided"); + SignatureSigningParameters parameters = resolveSigningParameters(relyingPartyRegistration); + return sign(authnRequest, parameters); } - private AuthnRequest sign(AuthnRequest authnRequest, Credential credential) { - SignatureSigningParameters parameters = new SignatureSigningParameters(); - parameters.setSigningCredential(credential); - parameters.setSignatureAlgorithm(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256); - parameters.setSignatureReferenceDigestMethod(SignatureConstants.ALGO_ID_DIGEST_SHA256); - parameters.setSignatureCanonicalizationAlgorithm(SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS); + private AuthnRequest sign(AuthnRequest authnRequest, SignatureSigningParameters parameters) { try { SignatureSupport.signObject(authnRequest, parameters); return authnRequest; } - catch (MarshallingException | SignatureException | SecurityException ex) { + catch (Exception ex) { throw new Saml2Exception(ex); } } - private Credential getSigningCredential(X509Certificate certificate, PrivateKey privateKey, String entityId) { - BasicCredential cred = CredentialSupport.getSimpleCredential(certificate, privateKey); - cred.setEntityId(entityId); - cred.setUsageType(UsageType.SIGNING); - return cred; + private void sign(Map components, RelyingPartyRegistration relyingPartyRegistration) { + SignatureSigningParameters parameters = resolveSigningParameters(relyingPartyRegistration); + sign(components, parameters); } - private Map signQueryParameters(Credential credential, String samlRequest, String relayState) { - Assert.notNull(samlRequest, "samlRequest cannot be null"); - String algorithmUri = SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256; - StringBuilder queryString = new StringBuilder(); - queryString.append("SAMLRequest").append("=").append(UriUtils.encode(samlRequest, StandardCharsets.ISO_8859_1)) - .append("&"); - if (StringUtils.hasText(relayState)) { - queryString.append("RelayState").append("=") - .append(UriUtils.encode(relayState, StandardCharsets.ISO_8859_1)).append("&"); + private void sign(Map components, SignatureSigningParameters parameters) { + Credential credential = parameters.getSigningCredential(); + String algorithmUri = parameters.getSignatureAlgorithm(); + components.put("SigAlg", algorithmUri); + UriComponentsBuilder builder = UriComponentsBuilder.newInstance(); + for (Map.Entry component : components.entrySet()) { + builder.queryParam(component.getKey(), UriUtils.encode(component.getValue(), StandardCharsets.ISO_8859_1)); } - queryString.append("SigAlg").append("=").append(UriUtils.encode(algorithmUri, StandardCharsets.ISO_8859_1)); + String queryString = builder.build(true).toString().substring(1); try { byte[] rawSignature = XMLSigningUtil.signWithURI(credential, algorithmUri, - queryString.toString().getBytes(StandardCharsets.UTF_8)); + queryString.getBytes(StandardCharsets.UTF_8)); String b64Signature = Saml2Utils.samlEncode(rawSignature); - Map result = new LinkedHashMap<>(); - result.put("SAMLRequest", samlRequest); - if (StringUtils.hasText(relayState)) { - result.put("RelayState", relayState); - } - result.put("SigAlg", algorithmUri); - result.put("Signature", b64Signature); - return result; + components.put("Signature", b64Signature); } catch (SecurityException ex) { throw new Saml2Exception(ex); @@ -280,4 +272,40 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication } } + private SignatureSigningParameters resolveSigningParameters(RelyingPartyRegistration relyingPartyRegistration) { + List credentials = resolveSigningCredentials(relyingPartyRegistration); + List algorithms = Collections.singletonList(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256); + List digests = Collections.singletonList(SignatureConstants.ALGO_ID_DIGEST_SHA256); + String canonicalization = SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS; + SignatureSigningParametersResolver resolver = new SAMLMetadataSignatureSigningParametersResolver(); + CriteriaSet criteria = new CriteriaSet(); + BasicSignatureSigningConfiguration signingConfiguration = new BasicSignatureSigningConfiguration(); + signingConfiguration.setSigningCredentials(credentials); + signingConfiguration.setSignatureAlgorithms(algorithms); + signingConfiguration.setSignatureReferenceDigestMethods(digests); + signingConfiguration.setSignatureCanonicalizationAlgorithm(canonicalization); + criteria.add(new SignatureSigningConfigurationCriterion(signingConfiguration)); + try { + SignatureSigningParameters parameters = resolver.resolveSingle(criteria); + Assert.notNull(parameters, "Failed to resolve any signing credential"); + return parameters; + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + } + + private List resolveSigningCredentials(RelyingPartyRegistration relyingPartyRegistration) { + List credentials = new ArrayList<>(); + for (Saml2X509Credential x509Credential : relyingPartyRegistration.getSigningX509Credentials()) { + X509Certificate certificate = x509Credential.getCertificate(); + PrivateKey privateKey = x509Credential.getPrivateKey(); + BasicCredential credential = CredentialSupport.getSimpleCredential(certificate, privateKey); + credential.setEntityId(relyingPartyRegistration.getEntityId()); + credential.setUsageType(UsageType.SIGNING); + credentials.add(credential); + } + return credentials; + } + } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java index 8ddb48ef79..5b378808b9 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java @@ -26,16 +26,20 @@ import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; import org.opensaml.saml.common.xml.SAMLConstants; import org.opensaml.saml.saml2.core.AuthnRequest; import org.opensaml.saml.saml2.core.impl.AuthnRequestUnmarshaller; +import org.opensaml.xmlsec.signature.support.SignatureConstants; import org.w3c.dom.Document; import org.w3c.dom.Element; import org.springframework.core.convert.converter.Converter; import org.springframework.security.saml2.Saml2Exception; +import org.springframework.security.saml2.core.Saml2X509Credential; import org.springframework.security.saml2.credentials.TestSaml2X509Credentials; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; +import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; @@ -110,6 +114,28 @@ public class OpenSamlAuthenticationRequestFactoryTests { assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT); } + @Test + public void createRedirectAuthenticationRequestWhenSignRequestThenSignatureIsPresent() { + this.context = this.contextBuilder.relayState("Relay State Value") + .relyingPartyRegistration(this.relyingPartyRegistration).build(); + Saml2RedirectAuthenticationRequest request = this.factory.createRedirectAuthenticationRequest(this.context); + assertThat(request.getRelayState()).isEqualTo("Relay State Value"); + assertThat(request.getSigAlg()).isEqualTo(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256); + assertThat(request.getSignature()).isNotNull(); + } + + @Test + public void createRedirectAuthenticationRequestWhenSignRequestThenCredentialIsRequired() { + Saml2X509Credential credential = org.springframework.security.saml2.core.TestSaml2X509Credentials + .relyingPartyVerifyingCredential(); + RelyingPartyRegistration registration = TestRelyingPartyRegistrations.noCredentials() + .assertingPartyDetails((party) -> party.verificationX509Credentials((c) -> c.add(credential))).build(); + this.context = this.contextBuilder.relayState("Relay State Value").relyingPartyRegistration(registration) + .build(); + assertThatExceptionOfType(Saml2Exception.class) + .isThrownBy(() -> this.factory.createPostAuthenticationRequest(this.context)); + } + @Test public void createPostAuthenticationRequestWhenNotSignRequestThenNoSignatureIsPresent() { this.context = this.contextBuilder.relayState("Relay State Value") @@ -139,6 +165,18 @@ public class OpenSamlAuthenticationRequestFactoryTests { .contains("ds:Signature"); } + @Test + public void createPostAuthenticationRequestWhenSignRequestThenCredentialIsRequired() { + Saml2X509Credential credential = org.springframework.security.saml2.core.TestSaml2X509Credentials + .relyingPartyVerifyingCredential(); + RelyingPartyRegistration registration = TestRelyingPartyRegistrations.noCredentials() + .assertingPartyDetails((party) -> party.verificationX509Credentials((c) -> c.add(credential))).build(); + this.context = this.contextBuilder.relayState("Relay State Value").relyingPartyRegistration(registration) + .build(); + assertThatExceptionOfType(Saml2Exception.class) + .isThrownBy(() -> this.factory.createPostAuthenticationRequest(this.context)); + } + @Test public void createAuthenticationRequestWhenDefaultThenReturnsPostBinding() { AuthnRequest authn = getAuthNRequest(Saml2MessageBinding.POST);