From bf5b334531d64ec75309b47166c1fdc11c8ac192 Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Mon, 5 Aug 2024 08:56:18 -0600 Subject: [PATCH] Use OpenSAML API for web.authentication Issue gh-11658 --- ...ing-security-saml2-service-provider.gradle | 7 + ...penSamlAuthenticationRequestResolver.java} | 103 ++- ...penSaml4AuthenticationRequestResolver.java | 24 +- .../web/authentication/OpenSaml4Template.java | 617 ++++++++++++++++++ .../authentication/OpenSamlOperations.java | 184 ++++++ .../authentication/OpenSamlSigningUtils.java | 193 ------ .../OpenSamlVerificationUtils.java | 206 ------ .../web/authentication/Saml2Utils.java | 122 +++- ...s.java => OpenSaml4SigningUtilsTests.java} | 6 +- ...amlAuthenticationRequestResolverTests.java | 270 -------- 10 files changed, 1016 insertions(+), 716 deletions(-) rename saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/{OpenSamlAuthenticationRequestResolver.java => BaseOpenSamlAuthenticationRequestResolver.java} (77%) create mode 100644 saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml4Template.java create mode 100644 saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlOperations.java delete mode 100644 saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlSigningUtils.java delete mode 100644 saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlVerificationUtils.java rename saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/{OpenSamlSigningUtilsTests.java => OpenSaml4SigningUtilsTests.java} (93%) delete mode 100644 saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolverTests.java diff --git a/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle b/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle index a2576d297b..180096cc02 100644 --- a/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle +++ b/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle @@ -31,6 +31,13 @@ sourceSets.configureEach { set -> filter { line -> line.replaceAll(".saml2.internal", ".saml2.provider.service.web.authentication.logout") } with from } + + copy { + into "$projectDir/src/$set.name/java/org/springframework/security/saml2/provider/service/web/authentication" + filter { line -> line.replaceAll(".saml2.internal", ".saml2.provider.service.web.authentication") } + with from + } + } dependencies { diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/BaseOpenSamlAuthenticationRequestResolver.java similarity index 77% rename from saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolver.java rename to saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/BaseOpenSamlAuthenticationRequestResolver.java index 2ec777a70e..1a19bacb60 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolver.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/BaseOpenSamlAuthenticationRequestResolver.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,19 +16,19 @@ package org.springframework.security.saml2.provider.service.web.authentication; -import java.nio.charset.StandardCharsets; +import java.time.Clock; +import java.time.Instant; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.UUID; -import java.util.function.BiConsumer; +import java.util.function.Consumer; import jakarta.servlet.http.HttpServletRequest; -import net.shibboleth.utilities.java.support.xml.SerializeSupport; import org.opensaml.core.config.ConfigurationService; import org.opensaml.core.xml.config.XMLObjectProviderRegistry; import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; -import org.opensaml.core.xml.io.MarshallingException; import org.opensaml.saml.saml2.core.AuthnRequest; import org.opensaml.saml.saml2.core.Issuer; import org.opensaml.saml.saml2.core.NameID; @@ -38,10 +38,8 @@ import org.opensaml.saml.saml2.core.impl.AuthnRequestMarshaller; import org.opensaml.saml.saml2.core.impl.IssuerBuilder; import org.opensaml.saml.saml2.core.impl.NameIDBuilder; import org.opensaml.saml.saml2.core.impl.NameIDPolicyBuilder; -import org.w3c.dom.Element; import org.springframework.core.convert.converter.Converter; -import org.springframework.security.saml2.Saml2Exception; import org.springframework.security.saml2.core.OpenSamlInitializationService; import org.springframework.security.saml2.core.Saml2ParameterNames; import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest; @@ -63,11 +61,14 @@ import org.springframework.util.Assert; * For internal use only. Intended for consolidating common behavior related to minting a * SAML 2.0 Authn Request. */ -class OpenSamlAuthenticationRequestResolver { +class BaseOpenSamlAuthenticationRequestResolver implements Saml2AuthenticationRequestResolver { static { OpenSamlInitializationService.initialize(); } + + private final OpenSamlOperations saml; + private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver; private final AuthnRequestBuilder authnRequestBuilder; @@ -84,15 +85,22 @@ class OpenSamlAuthenticationRequestResolver { new AntPathRequestMatcher(Saml2AuthenticationRequestResolver.DEFAULT_AUTHENTICATION_REQUEST_URI), new AntPathQueryRequestMatcher("/saml2/authenticate", "registrationId={registrationId}")); + private Clock clock = Clock.systemUTC(); + private Converter relayStateResolver = (request) -> UUID.randomUUID().toString(); + private Consumer parametersConsumer = (parameters) -> { + }; + /** - * Construct a {@link OpenSamlAuthenticationRequestResolver} using the provided + * Construct a {@link BaseOpenSamlAuthenticationRequestResolver} using the provided * parameters * @param relyingPartyRegistrationResolver a strategy for resolving the * {@link RelyingPartyRegistration} from the {@link HttpServletRequest} */ - OpenSamlAuthenticationRequestResolver(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) { + BaseOpenSamlAuthenticationRequestResolver(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver, + OpenSamlOperations saml) { + this.saml = saml; Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null"); this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver; XMLObjectProviderRegistry registry = ConfigurationService.get(XMLObjectProviderRegistry.class); @@ -111,6 +119,10 @@ class OpenSamlAuthenticationRequestResolver { Assert.notNull(this.nameIdPolicyBuilder, "nameIdPolicyBuilder must be configured in OpenSAML"); } + void setClock(Clock clock) { + this.clock = clock; + } + void setRelayStateResolver(Converter relayStateResolver) { this.relayStateResolver = relayStateResolver; } @@ -119,13 +131,12 @@ class OpenSamlAuthenticationRequestResolver { this.requestMatcher = requestMatcher; } - T resolve(HttpServletRequest request) { - return resolve(request, (registration, logoutRequest) -> { - }); + void setParametersConsumer(Consumer parametersConsumer) { + this.parametersConsumer = parametersConsumer; } - T resolve(HttpServletRequest request, - BiConsumer authnRequestConsumer) { + @Override + public T resolve(HttpServletRequest request) { RequestMatcher.MatchResult result = this.requestMatcher.matcher(request); if (!result.isMatch()) { return null; @@ -153,7 +164,8 @@ class OpenSamlAuthenticationRequestResolver { nameIdPolicy.setFormat(registration.getNameIdFormat()); authnRequest.setNameIDPolicy(nameIdPolicy); } - authnRequestConsumer.accept(registration, authnRequest); + authnRequest.setIssueInstant(Instant.now(this.clock)); + this.parametersConsumer.accept(new AuthnRequestParameters(request, registration, authnRequest)); if (authnRequest.getID() == null) { authnRequest.setID("ARQ" + UUID.randomUUID().toString().substring(1)); } @@ -162,10 +174,12 @@ class OpenSamlAuthenticationRequestResolver { if (binding == Saml2MessageBinding.POST) { if (registration.getAssertingPartyMetadata().getWantAuthnRequestsSigned() || registration.isAuthnRequestsSigned()) { - OpenSamlSigningUtils.sign(authnRequest, registration); + this.saml.withSigningKeys(registration.getSigningX509Credentials()) + .algorithms(registration.getAssertingPartyMetadata().getSigningAlgorithms()) + .sign(authnRequest); } String xml = serialize(authnRequest); - String encoded = Saml2Utils.samlEncode(xml.getBytes(StandardCharsets.UTF_8)); + String encoded = Saml2Utils.withDecoded(xml).encode(); return (T) Saml2PostAuthenticationRequest.withRelyingPartyRegistration(registration) .samlRequest(encoded) .relayState(relayState) @@ -174,7 +188,7 @@ class OpenSamlAuthenticationRequestResolver { } else { String xml = serialize(authnRequest); - String deflatedAndEncoded = Saml2Utils.samlEncode(Saml2Utils.samlDeflate(xml)); + String deflatedAndEncoded = Saml2Utils.withDecoded(xml).deflate(true).encode(); Saml2RedirectAuthenticationRequest.Builder builder = Saml2RedirectAuthenticationRequest .withRelyingPartyRegistration(registration) .samlRequest(deflatedAndEncoded) @@ -182,27 +196,23 @@ class OpenSamlAuthenticationRequestResolver { .id(authnRequest.getID()); if (registration.getAssertingPartyMetadata().getWantAuthnRequestsSigned() || registration.isAuthnRequestsSigned()) { - OpenSamlSigningUtils.QueryParametersPartial parametersPartial = OpenSamlSigningUtils.sign(registration) - .param(Saml2ParameterNames.SAML_REQUEST, deflatedAndEncoded); + Map signingParameters = new HashMap<>(); + signingParameters.put(Saml2ParameterNames.SAML_REQUEST, deflatedAndEncoded); if (relayState != null) { - parametersPartial = parametersPartial.param(Saml2ParameterNames.RELAY_STATE, relayState); + signingParameters.put(Saml2ParameterNames.RELAY_STATE, relayState); } - Map parameters = parametersPartial.parameters(); - builder.sigAlg(parameters.get(Saml2ParameterNames.SIG_ALG)) - .signature(parameters.get(Saml2ParameterNames.SIGNATURE)); + Map query = this.saml.withSigningKeys(registration.getSigningX509Credentials()) + .algorithms(registration.getAssertingPartyMetadata().getSigningAlgorithms()) + .sign(signingParameters); + builder.sigAlg(query.get(Saml2ParameterNames.SIG_ALG)) + .signature(query.get(Saml2ParameterNames.SIGNATURE)); } return (T) builder.build(); } } private String serialize(AuthnRequest authnRequest) { - try { - Element element = this.marshaller.marshall(authnRequest); - return SerializeSupport.nodeToString(element); - } - catch (MarshallingException ex) { - throw new Saml2Exception(ex); - } + return this.saml.serialize(authnRequest).serialize(); } private static final class AntPathQueryRequestMatcher implements RequestMatcher { @@ -236,4 +246,33 @@ class OpenSamlAuthenticationRequestResolver { } + static final class AuthnRequestParameters { + + private final HttpServletRequest request; + + private final RelyingPartyRegistration registration; + + private final AuthnRequest authnRequest; + + AuthnRequestParameters(HttpServletRequest request, RelyingPartyRegistration registration, + AuthnRequest authnRequest) { + this.request = request; + this.registration = registration; + this.authnRequest = authnRequest; + } + + HttpServletRequest getRequest() { + return this.request; + } + + RelyingPartyRegistration getRelyingPartyRegistration() { + return this.registration; + } + + AuthnRequest getAuthnRequest() { + return this.authnRequest; + } + + } + } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml4AuthenticationRequestResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml4AuthenticationRequestResolver.java index e0c434e21c..9d9cb41215 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml4AuthenticationRequestResolver.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml4AuthenticationRequestResolver.java @@ -40,7 +40,7 @@ import org.springframework.util.Assert; */ public final class OpenSaml4AuthenticationRequestResolver implements Saml2AuthenticationRequestResolver { - private final OpenSamlAuthenticationRequestResolver authnRequestResolver; + private final BaseOpenSamlAuthenticationRequestResolver delegate; private Consumer contextConsumer = (parameters) -> { }; @@ -53,27 +53,25 @@ public final class OpenSaml4AuthenticationRequestResolver implements Saml2Authen * @since 6.1 */ public OpenSaml4AuthenticationRequestResolver(RelyingPartyRegistrationRepository registrations) { - this.authnRequestResolver = new OpenSamlAuthenticationRequestResolver((request, id) -> { + this.delegate = new BaseOpenSamlAuthenticationRequestResolver((request, id) -> { if (id == null) { return null; } return registrations.findByRegistrationId(id); - }); + }, new OpenSaml4Template()); } /** * Construct a {@link OpenSaml4AuthenticationRequestResolver} */ public OpenSaml4AuthenticationRequestResolver(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) { - this.authnRequestResolver = new OpenSamlAuthenticationRequestResolver(relyingPartyRegistrationResolver); + this.delegate = new BaseOpenSamlAuthenticationRequestResolver(relyingPartyRegistrationResolver, + new OpenSaml4Template()); } @Override public T resolve(HttpServletRequest request) { - return this.authnRequestResolver.resolve(request, (registration, authnRequest) -> { - authnRequest.setIssueInstant(Instant.now(this.clock)); - this.contextConsumer.accept(new AuthnRequestContext(request, registration, authnRequest)); - }); + return this.delegate.resolve(request); } /** @@ -82,12 +80,14 @@ public final class OpenSaml4AuthenticationRequestResolver implements Saml2Authen */ public void setAuthnRequestCustomizer(Consumer contextConsumer) { Assert.notNull(contextConsumer, "contextConsumer cannot be null"); - this.contextConsumer = contextConsumer; + this.delegate.setParametersConsumer( + (parameters) -> contextConsumer.accept(new AuthnRequestContext(parameters.getRequest(), + parameters.getRelyingPartyRegistration(), parameters.getAuthnRequest()))); } /** * Set the {@link RequestMatcher} to use for setting the - * {@link OpenSamlAuthenticationRequestResolver#setRequestMatcher(RequestMatcher)} + * {@link BaseOpenSamlAuthenticationRequestResolver#setRequestMatcher(RequestMatcher)} * (RequestMatcher)} * @param requestMatcher the {@link RequestMatcher} to identify authentication * requests. @@ -95,7 +95,7 @@ public final class OpenSaml4AuthenticationRequestResolver implements Saml2Authen */ public void setRequestMatcher(RequestMatcher requestMatcher) { Assert.notNull(requestMatcher, "requestMatcher cannot be null"); - this.authnRequestResolver.setRequestMatcher(requestMatcher); + this.delegate.setRequestMatcher(requestMatcher); } /** @@ -114,7 +114,7 @@ public final class OpenSaml4AuthenticationRequestResolver implements Saml2Authen */ public void setRelayStateResolver(Converter relayStateResolver) { Assert.notNull(relayStateResolver, "relayStateResolver cannot be null"); - this.authnRequestResolver.setRelayStateResolver(relayStateResolver); + this.delegate.setRelayStateResolver(relayStateResolver); } public static final class AuthnRequestContext { diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml4Template.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml4Template.java new file mode 100644 index 0000000000..9ca1253379 --- /dev/null +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml4Template.java @@ -0,0 +1,617 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web.authentication; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import javax.xml.namespace.QName; + +import net.shibboleth.utilities.java.support.resolver.CriteriaSet; +import net.shibboleth.utilities.java.support.xml.SerializeSupport; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.opensaml.core.criterion.EntityIdCriterion; +import org.opensaml.core.xml.XMLObject; +import org.opensaml.core.xml.XMLObjectBuilder; +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; +import org.opensaml.core.xml.io.Marshaller; +import org.opensaml.core.xml.io.MarshallingException; +import org.opensaml.core.xml.io.Unmarshaller; +import org.opensaml.core.xml.io.UnmarshallerFactory; +import org.opensaml.core.xml.util.XMLObjectSupport; +import org.opensaml.saml.common.xml.SAMLConstants; +import org.opensaml.saml.criterion.ProtocolCriterion; +import org.opensaml.saml.ext.saml2delrestrict.Delegate; +import org.opensaml.saml.ext.saml2delrestrict.DelegationRestrictionType; +import org.opensaml.saml.metadata.criteria.role.impl.EvaluableProtocolRoleDescriptorCriterion; +import org.opensaml.saml.saml2.core.Assertion; +import org.opensaml.saml.saml2.core.Attribute; +import org.opensaml.saml.saml2.core.AttributeStatement; +import org.opensaml.saml.saml2.core.Condition; +import org.opensaml.saml.saml2.core.EncryptedAssertion; +import org.opensaml.saml.saml2.core.EncryptedAttribute; +import org.opensaml.saml.saml2.core.Issuer; +import org.opensaml.saml.saml2.core.LogoutRequest; +import org.opensaml.saml.saml2.core.NameID; +import org.opensaml.saml.saml2.core.RequestAbstractType; +import org.opensaml.saml.saml2.core.Response; +import org.opensaml.saml.saml2.core.StatusResponseType; +import org.opensaml.saml.saml2.core.Subject; +import org.opensaml.saml.saml2.core.SubjectConfirmation; +import org.opensaml.saml.saml2.encryption.Decrypter; +import org.opensaml.saml.saml2.encryption.EncryptedElementTypeEncryptedKeyResolver; +import org.opensaml.saml.security.impl.SAMLMetadataSignatureSigningParametersResolver; +import org.opensaml.saml.security.impl.SAMLSignatureProfileValidator; +import org.opensaml.security.SecurityException; +import org.opensaml.security.credential.BasicCredential; +import org.opensaml.security.credential.Credential; +import org.opensaml.security.credential.CredentialResolver; +import org.opensaml.security.credential.CredentialSupport; +import org.opensaml.security.credential.UsageType; +import org.opensaml.security.credential.criteria.impl.EvaluableEntityIDCredentialCriterion; +import org.opensaml.security.credential.criteria.impl.EvaluableUsageCredentialCriterion; +import org.opensaml.security.credential.impl.CollectionCredentialResolver; +import org.opensaml.security.criteria.UsageCriterion; +import org.opensaml.security.x509.BasicX509Credential; +import org.opensaml.xmlsec.SignatureSigningParameters; +import org.opensaml.xmlsec.SignatureSigningParametersResolver; +import org.opensaml.xmlsec.config.impl.DefaultSecurityConfigurationBootstrap; +import org.opensaml.xmlsec.criterion.SignatureSigningConfigurationCriterion; +import org.opensaml.xmlsec.crypto.XMLSigningUtil; +import org.opensaml.xmlsec.encryption.support.ChainingEncryptedKeyResolver; +import org.opensaml.xmlsec.encryption.support.DecryptionException; +import org.opensaml.xmlsec.encryption.support.EncryptedKeyResolver; +import org.opensaml.xmlsec.encryption.support.InlineEncryptedKeyResolver; +import org.opensaml.xmlsec.encryption.support.SimpleRetrievalMethodEncryptedKeyResolver; +import org.opensaml.xmlsec.impl.BasicSignatureSigningConfiguration; +import org.opensaml.xmlsec.keyinfo.KeyInfoCredentialResolver; +import org.opensaml.xmlsec.keyinfo.KeyInfoGeneratorManager; +import org.opensaml.xmlsec.keyinfo.NamedKeyInfoGeneratorManager; +import org.opensaml.xmlsec.keyinfo.impl.CollectionKeyInfoCredentialResolver; +import org.opensaml.xmlsec.keyinfo.impl.X509KeyInfoGeneratorFactory; +import org.opensaml.xmlsec.signature.SignableXMLObject; +import org.opensaml.xmlsec.signature.Signature; +import org.opensaml.xmlsec.signature.support.SignatureConstants; +import org.opensaml.xmlsec.signature.support.SignatureSupport; +import org.opensaml.xmlsec.signature.support.SignatureTrustEngine; +import org.opensaml.xmlsec.signature.support.impl.ExplicitKeySignatureTrustEngine; +import org.w3c.dom.Document; +import org.w3c.dom.Element; + +import org.springframework.security.saml2.Saml2Exception; +import org.springframework.security.saml2.core.Saml2Error; +import org.springframework.security.saml2.core.Saml2ErrorCodes; +import org.springframework.security.saml2.core.Saml2ParameterNames; +import org.springframework.security.saml2.core.Saml2X509Credential; +import org.springframework.util.Assert; +import org.springframework.web.util.UriComponentsBuilder; +import org.springframework.web.util.UriUtils; + +/** + * For internal use only. Subject to breaking changes at any time. + */ +final class OpenSaml4Template implements OpenSamlOperations { + + private static final Log logger = LogFactory.getLog(OpenSaml4Template.class); + + @Override + public T build(QName elementName) { + XMLObjectBuilder builder = XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(elementName); + if (builder == null) { + throw new Saml2Exception("Unable to resolve Builder for " + elementName); + } + return (T) builder.buildObject(elementName); + } + + @Override + public T deserialize(String serialized) { + return deserialize(new ByteArrayInputStream(serialized.getBytes(StandardCharsets.UTF_8))); + } + + @Override + public T deserialize(InputStream serialized) { + try { + Document document = XMLObjectProviderRegistrySupport.getParserPool().parse(serialized); + Element element = document.getDocumentElement(); + UnmarshallerFactory factory = XMLObjectProviderRegistrySupport.getUnmarshallerFactory(); + Unmarshaller unmarshaller = factory.getUnmarshaller(element); + if (unmarshaller == null) { + throw new Saml2Exception("Unsupported element of type " + element.getTagName()); + } + return (T) unmarshaller.unmarshall(element); + } + catch (Saml2Exception ex) { + throw ex; + } + catch (Exception ex) { + throw new Saml2Exception("Failed to deserialize payload", ex); + } + } + + @Override + public OpenSaml4SerializationConfigurer serialize(XMLObject object) { + Marshaller marshaller = XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(object); + try { + return serialize(marshaller.marshall(object)); + } + catch (MarshallingException ex) { + throw new Saml2Exception(ex); + } + } + + @Override + public OpenSaml4SerializationConfigurer serialize(Element element) { + return new OpenSaml4SerializationConfigurer(element); + } + + @Override + public OpenSaml4SignatureConfigurer withSigningKeys(Collection credentials) { + return new OpenSaml4SignatureConfigurer(credentials); + } + + @Override + public OpenSaml4VerificationConfigurer withVerificationKeys(Collection credentials) { + return new OpenSaml4VerificationConfigurer(credentials); + } + + @Override + public OpenSaml4DecryptionConfigurer withDecryptionKeys(Collection credentials) { + return new OpenSaml4DecryptionConfigurer(credentials); + } + + OpenSaml4Template() { + + } + + static final class OpenSaml4SerializationConfigurer + implements SerializationConfigurer { + + private final Element element; + + boolean pretty; + + OpenSaml4SerializationConfigurer(Element element) { + this.element = element; + } + + @Override + public OpenSaml4SerializationConfigurer prettyPrint(boolean pretty) { + this.pretty = pretty; + return this; + } + + @Override + public String serialize() { + if (this.pretty) { + return SerializeSupport.prettyPrintXML(this.element); + } + return SerializeSupport.nodeToString(this.element); + } + + } + + static final class OpenSaml4SignatureConfigurer implements SignatureConfigurer { + + private final Collection credentials; + + private final Map components = new LinkedHashMap<>(); + + private List algs = List.of(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256); + + OpenSaml4SignatureConfigurer(Collection credentials) { + this.credentials = credentials; + } + + @Override + public OpenSaml4SignatureConfigurer algorithms(List algs) { + this.algs = algs; + return this; + } + + @Override + public O sign(O object) { + SignatureSigningParameters parameters = resolveSigningParameters(); + try { + SignatureSupport.signObject(object, parameters); + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + return object; + } + + @Override + public Map sign(Map params) { + SignatureSigningParameters parameters = resolveSigningParameters(); + this.components.putAll(params); + Credential credential = parameters.getSigningCredential(); + String algorithmUri = parameters.getSignatureAlgorithm(); + this.components.put(Saml2ParameterNames.SIG_ALG, algorithmUri); + UriComponentsBuilder builder = UriComponentsBuilder.newInstance(); + for (Map.Entry component : this.components.entrySet()) { + builder.queryParam(component.getKey(), + UriUtils.encode(component.getValue(), StandardCharsets.ISO_8859_1)); + } + String queryString = builder.build(true).toString().substring(1); + try { + byte[] rawSignature = XMLSigningUtil.signWithURI(credential, algorithmUri, + queryString.getBytes(StandardCharsets.UTF_8)); + String b64Signature = Saml2Utils.samlEncode(rawSignature); + this.components.put(Saml2ParameterNames.SIGNATURE, b64Signature); + } + catch (SecurityException ex) { + throw new Saml2Exception(ex); + } + return this.components; + } + + private SignatureSigningParameters resolveSigningParameters() { + List credentials = resolveSigningCredentials(); + List digests = Collections.singletonList(SignatureConstants.ALGO_ID_DIGEST_SHA256); + String canonicalization = SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS; + SignatureSigningParametersResolver resolver = new SAMLMetadataSignatureSigningParametersResolver(); + BasicSignatureSigningConfiguration signingConfiguration = new BasicSignatureSigningConfiguration(); + signingConfiguration.setSigningCredentials(credentials); + signingConfiguration.setSignatureAlgorithms(this.algs); + signingConfiguration.setSignatureReferenceDigestMethods(digests); + signingConfiguration.setSignatureCanonicalizationAlgorithm(canonicalization); + signingConfiguration.setKeyInfoGeneratorManager(buildSignatureKeyInfoGeneratorManager()); + CriteriaSet criteria = new CriteriaSet(new SignatureSigningConfigurationCriterion(signingConfiguration)); + try { + SignatureSigningParameters parameters = resolver.resolveSingle(criteria); + Assert.notNull(parameters, "Failed to resolve any signing credential"); + return parameters; + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + } + + private NamedKeyInfoGeneratorManager buildSignatureKeyInfoGeneratorManager() { + final NamedKeyInfoGeneratorManager namedManager = new NamedKeyInfoGeneratorManager(); + + namedManager.setUseDefaultManager(true); + final KeyInfoGeneratorManager defaultManager = namedManager.getDefaultManager(); + + // Generator for X509Credentials + final X509KeyInfoGeneratorFactory x509Factory = new X509KeyInfoGeneratorFactory(); + x509Factory.setEmitEntityCertificate(true); + x509Factory.setEmitEntityCertificateChain(true); + + defaultManager.registerFactory(x509Factory); + + return namedManager; + } + + private List resolveSigningCredentials() { + List credentials = new ArrayList<>(); + for (Saml2X509Credential x509Credential : this.credentials) { + X509Certificate certificate = x509Credential.getCertificate(); + PrivateKey privateKey = x509Credential.getPrivateKey(); + BasicCredential credential = CredentialSupport.getSimpleCredential(certificate, privateKey); + credential.setUsageType(UsageType.SIGNING); + credentials.add(credential); + } + return credentials; + } + + } + + static final class OpenSaml4VerificationConfigurer implements VerificationConfigurer { + + private final Collection credentials; + + private String entityId; + + OpenSaml4VerificationConfigurer(Collection credentials) { + this.credentials = credentials; + } + + @Override + public VerificationConfigurer entityId(String entityId) { + this.entityId = entityId; + return this; + } + + private SignatureTrustEngine trustEngine(Collection keys) { + Set credentials = new HashSet<>(); + for (Saml2X509Credential key : keys) { + BasicX509Credential cred = new BasicX509Credential(key.getCertificate()); + cred.setUsageType(UsageType.SIGNING); + cred.setEntityId(this.entityId); + credentials.add(cred); + } + CredentialResolver credentialsResolver = new CollectionCredentialResolver(credentials); + return new ExplicitKeySignatureTrustEngine(credentialsResolver, + DefaultSecurityConfigurationBootstrap.buildBasicInlineKeyInfoCredentialResolver()); + } + + private CriteriaSet verificationCriteria(Issuer issuer) { + return new CriteriaSet(new EvaluableEntityIDCredentialCriterion(new EntityIdCriterion(issuer.getValue())), + new EvaluableProtocolRoleDescriptorCriterion(new ProtocolCriterion(SAMLConstants.SAML20P_NS)), + new EvaluableUsageCredentialCriterion(new UsageCriterion(UsageType.SIGNING))); + } + + @Override + public Collection verify(SignableXMLObject signable) { + if (signable instanceof StatusResponseType response) { + return verifySignature(response.getID(), response.getIssuer(), response.getSignature()); + } + if (signable instanceof RequestAbstractType request) { + return verifySignature(request.getID(), request.getIssuer(), request.getSignature()); + } + if (signable instanceof Assertion assertion) { + return verifySignature(assertion.getID(), assertion.getIssuer(), assertion.getSignature()); + } + throw new Saml2Exception("Unsupported object of type: " + signable.getClass().getName()); + } + + private Collection verifySignature(String id, Issuer issuer, Signature signature) { + SignatureTrustEngine trustEngine = trustEngine(this.credentials); + CriteriaSet criteria = verificationCriteria(issuer); + Collection errors = new ArrayList<>(); + SAMLSignatureProfileValidator profileValidator = new SAMLSignatureProfileValidator(); + try { + profileValidator.validate(signature); + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + id + "]: ")); + } + + try { + if (!trustEngine.validate(signature, criteria)) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + id + "]")); + } + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + id + "]: ")); + } + + return errors; + } + + @Override + public Collection verify(RedirectParameters parameters) { + SignatureTrustEngine trustEngine = trustEngine(this.credentials); + CriteriaSet criteria = verificationCriteria(parameters.getIssuer()); + if (parameters.getAlgorithm() == null) { + return Collections.singletonList(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Missing signature algorithm for object [" + parameters.getId() + "]")); + } + if (!parameters.hasSignature()) { + return Collections.singletonList(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Missing signature for object [" + parameters.getId() + "]")); + } + Collection errors = new ArrayList<>(); + String algorithmUri = parameters.getAlgorithm(); + try { + if (!trustEngine.validate(parameters.getSignature(), parameters.getContent(), algorithmUri, criteria, + null)) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + parameters.getId() + "]")); + } + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + parameters.getId() + "]: ")); + } + return errors; + } + + } + + static final class OpenSaml4DecryptionConfigurer implements DecryptionConfigurer { + + private static final EncryptedKeyResolver encryptedKeyResolver = new ChainingEncryptedKeyResolver( + Arrays.asList(new InlineEncryptedKeyResolver(), new EncryptedElementTypeEncryptedKeyResolver(), + new SimpleRetrievalMethodEncryptedKeyResolver())); + + private final Decrypter decrypter; + + OpenSaml4DecryptionConfigurer(Collection decryptionCredentials) { + this.decrypter = decrypter(decryptionCredentials); + } + + private static Decrypter decrypter(Collection decryptionCredentials) { + Collection credentials = new ArrayList<>(); + for (Saml2X509Credential key : decryptionCredentials) { + Credential cred = CredentialSupport.getSimpleCredential(key.getCertificate(), key.getPrivateKey()); + credentials.add(cred); + } + KeyInfoCredentialResolver resolver = new CollectionKeyInfoCredentialResolver(credentials); + Decrypter decrypter = new Decrypter(null, resolver, encryptedKeyResolver); + decrypter.setRootInNewDocument(true); + return decrypter; + } + + @Override + public void decrypt(XMLObject object) { + if (object instanceof Response response) { + decryptResponse(response); + return; + } + if (object instanceof Assertion assertion) { + decryptAssertion(assertion); + } + if (object instanceof LogoutRequest request) { + decryptLogoutRequest(request); + } + } + + /* + * The methods that follow are adapted from OpenSAML's {@link DecryptAssertions}, + * {@link DecryptNameIDs}, and {@link DecryptAttributes}. + * + *

The reason that these OpenSAML classes are not used directly is because they + * reference {@link javax.servlet.http.HttpServletRequest} which is a lower + * Servlet API version than what Spring Security SAML uses. + * + * If OpenSAML 5 updates to {@link jakarta.servlet.http.HttpServletRequest}, then + * this arrangement can be revisited. + */ + + private void decryptResponse(Response response) { + Collection decrypteds = new ArrayList<>(); + Collection encrypteds = new ArrayList<>(); + + int count = 0; + int size = response.getEncryptedAssertions().size(); + for (EncryptedAssertion encrypted : response.getEncryptedAssertions()) { + logger.trace(String.format("Decrypting EncryptedAssertion (%d/%d) in Response [%s]", count, size, + response.getID())); + try { + Assertion decrypted = this.decrypter.decrypt(encrypted); + if (decrypted != null) { + encrypteds.add(encrypted); + decrypteds.add(decrypted); + } + count++; + } + catch (DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + + response.getEncryptedAssertions().removeAll(encrypteds); + response.getAssertions().addAll(decrypteds); + + // Re-marshall the response so that any ID attributes within the decrypted + // Assertions + // will have their ID-ness re-established at the DOM level. + if (!decrypteds.isEmpty()) { + try { + XMLObjectSupport.marshall(response); + } + catch (final MarshallingException ex) { + throw new Saml2Exception(ex); + } + } + } + + private void decryptAssertion(Assertion assertion) { + for (AttributeStatement statement : assertion.getAttributeStatements()) { + decryptAttributes(statement); + } + decryptSubject(assertion.getSubject()); + if (assertion.getConditions() != null) { + for (Condition c : assertion.getConditions().getConditions()) { + if (!(c instanceof DelegationRestrictionType delegation)) { + continue; + } + for (Delegate d : delegation.getDelegates()) { + if (d.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(d.getEncryptedID()); + if (decrypted != null) { + d.setNameID(decrypted); + d.setEncryptedID(null); + } + } + catch (DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + } + } + } + } + + private void decryptAttributes(AttributeStatement statement) { + Collection decrypteds = new ArrayList<>(); + Collection encrypteds = new ArrayList<>(); + for (EncryptedAttribute encrypted : statement.getEncryptedAttributes()) { + try { + Attribute decrypted = this.decrypter.decrypt(encrypted); + if (decrypted != null) { + encrypteds.add(encrypted); + decrypteds.add(decrypted); + } + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + } + statement.getEncryptedAttributes().removeAll(encrypteds); + statement.getAttributes().addAll(decrypteds); + } + + private void decryptSubject(Subject subject) { + if (subject != null) { + if (subject.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(subject.getEncryptedID()); + if (decrypted != null) { + subject.setNameID(decrypted); + subject.setEncryptedID(null); + } + } + catch (final DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + + for (final SubjectConfirmation sc : subject.getSubjectConfirmations()) { + if (sc.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(sc.getEncryptedID()); + if (decrypted != null) { + sc.setNameID(decrypted); + sc.setEncryptedID(null); + } + } + catch (final DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + } + } + } + + private void decryptLogoutRequest(LogoutRequest request) { + if (request.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(request.getEncryptedID()); + if (decrypted != null) { + request.setNameID(decrypted); + request.setEncryptedID(null); + } + } + catch (DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + } + + } + +} diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlOperations.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlOperations.java new file mode 100644 index 0000000000..7fb091a8be --- /dev/null +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlOperations.java @@ -0,0 +1,184 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web.authentication; + +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import javax.xml.namespace.QName; + +import org.opensaml.core.xml.XMLObject; +import org.opensaml.saml.saml2.core.Issuer; +import org.opensaml.saml.saml2.core.RequestAbstractType; +import org.opensaml.saml.saml2.core.StatusResponseType; +import org.opensaml.xmlsec.signature.SignableXMLObject; +import org.w3c.dom.Element; + +import org.springframework.security.saml2.core.Saml2Error; +import org.springframework.security.saml2.core.Saml2ParameterNames; +import org.springframework.security.saml2.core.Saml2X509Credential; +import org.springframework.web.util.UriComponentsBuilder; + +interface OpenSamlOperations { + + T build(QName elementName); + + T deserialize(String serialized); + + T deserialize(InputStream serialized); + + SerializationConfigurer serialize(XMLObject object); + + SerializationConfigurer serialize(Element element); + + SignatureConfigurer withSigningKeys(Collection credentials); + + VerificationConfigurer withVerificationKeys(Collection credentials); + + DecryptionConfigurer withDecryptionKeys(Collection credentials); + + interface SerializationConfigurer> { + + B prettyPrint(boolean pretty); + + String serialize(); + + } + + interface SignatureConfigurer> { + + B algorithms(List algs); + + O sign(O object); + + Map sign(Map params); + + } + + interface VerificationConfigurer { + + VerificationConfigurer entityId(String entityId); + + Collection verify(SignableXMLObject signable); + + Collection verify(VerificationConfigurer.RedirectParameters parameters); + + final class RedirectParameters { + + private final String id; + + private final Issuer issuer; + + private final String algorithm; + + private final byte[] signature; + + private final byte[] content; + + RedirectParameters(Map parameters, String parametersQuery, RequestAbstractType request) { + this.id = request.getID(); + this.issuer = request.getIssuer(); + this.algorithm = parameters.get(Saml2ParameterNames.SIG_ALG); + if (parameters.get(Saml2ParameterNames.SIGNATURE) != null) { + this.signature = Saml2Utils.samlDecode(parameters.get(Saml2ParameterNames.SIGNATURE)); + } + else { + this.signature = null; + } + Map queryParams = UriComponentsBuilder.newInstance() + .query(parametersQuery) + .build(true) + .getQueryParams() + .toSingleValueMap(); + String relayState = parameters.get(Saml2ParameterNames.RELAY_STATE); + this.content = getContent(Saml2ParameterNames.SAML_REQUEST, relayState, queryParams); + } + + RedirectParameters(Map parameters, String parametersQuery, StatusResponseType response) { + this.id = response.getID(); + this.issuer = response.getIssuer(); + this.algorithm = parameters.get(Saml2ParameterNames.SIG_ALG); + if (parameters.get(Saml2ParameterNames.SIGNATURE) != null) { + this.signature = Saml2Utils.samlDecode(parameters.get(Saml2ParameterNames.SIGNATURE)); + } + else { + this.signature = null; + } + Map queryParams = UriComponentsBuilder.newInstance() + .query(parametersQuery) + .build(true) + .getQueryParams() + .toSingleValueMap(); + String relayState = parameters.get(Saml2ParameterNames.RELAY_STATE); + this.content = getContent(Saml2ParameterNames.SAML_RESPONSE, relayState, queryParams); + } + + static byte[] getContent(String samlObject, String relayState, final Map queryParams) { + if (Objects.nonNull(relayState)) { + return String + .format("%s=%s&%s=%s&%s=%s", samlObject, queryParams.get(samlObject), + Saml2ParameterNames.RELAY_STATE, queryParams.get(Saml2ParameterNames.RELAY_STATE), + Saml2ParameterNames.SIG_ALG, queryParams.get(Saml2ParameterNames.SIG_ALG)) + .getBytes(StandardCharsets.UTF_8); + } + else { + return String + .format("%s=%s&%s=%s", samlObject, queryParams.get(samlObject), Saml2ParameterNames.SIG_ALG, + queryParams.get(Saml2ParameterNames.SIG_ALG)) + .getBytes(StandardCharsets.UTF_8); + } + } + + String getId() { + return this.id; + } + + Issuer getIssuer() { + return this.issuer; + } + + byte[] getContent() { + return this.content; + } + + String getAlgorithm() { + return this.algorithm; + } + + byte[] getSignature() { + return this.signature; + } + + boolean hasSignature() { + return this.signature != null; + } + + } + + } + + interface DecryptionConfigurer { + + void decrypt(XMLObject object); + + } + +} diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlSigningUtils.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlSigningUtils.java deleted file mode 100644 index c479b06cf5..0000000000 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlSigningUtils.java +++ /dev/null @@ -1,193 +0,0 @@ -/* - * Copyright 2002-2022 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.saml2.provider.service.web.authentication; - -import java.nio.charset.StandardCharsets; -import java.security.PrivateKey; -import java.security.cert.X509Certificate; -import java.util.ArrayList; -import java.util.Collections; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; - -import net.shibboleth.utilities.java.support.resolver.CriteriaSet; -import net.shibboleth.utilities.java.support.xml.SerializeSupport; -import org.opensaml.core.xml.XMLObject; -import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; -import org.opensaml.core.xml.io.Marshaller; -import org.opensaml.core.xml.io.MarshallingException; -import org.opensaml.saml.security.impl.SAMLMetadataSignatureSigningParametersResolver; -import org.opensaml.security.SecurityException; -import org.opensaml.security.credential.BasicCredential; -import org.opensaml.security.credential.Credential; -import org.opensaml.security.credential.CredentialSupport; -import org.opensaml.security.credential.UsageType; -import org.opensaml.xmlsec.SignatureSigningParameters; -import org.opensaml.xmlsec.SignatureSigningParametersResolver; -import org.opensaml.xmlsec.criterion.SignatureSigningConfigurationCriterion; -import org.opensaml.xmlsec.crypto.XMLSigningUtil; -import org.opensaml.xmlsec.impl.BasicSignatureSigningConfiguration; -import org.opensaml.xmlsec.keyinfo.KeyInfoGeneratorManager; -import org.opensaml.xmlsec.keyinfo.NamedKeyInfoGeneratorManager; -import org.opensaml.xmlsec.keyinfo.impl.X509KeyInfoGeneratorFactory; -import org.opensaml.xmlsec.signature.SignableXMLObject; -import org.opensaml.xmlsec.signature.support.SignatureConstants; -import org.opensaml.xmlsec.signature.support.SignatureSupport; -import org.w3c.dom.Element; - -import org.springframework.security.saml2.Saml2Exception; -import org.springframework.security.saml2.core.Saml2X509Credential; -import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; -import org.springframework.util.Assert; -import org.springframework.web.util.UriComponentsBuilder; -import org.springframework.web.util.UriUtils; - -/** - * Utility methods for signing SAML components with OpenSAML - * - * For internal use only. - * - * @author Josh Cummings - */ -final class OpenSamlSigningUtils { - - static String serialize(XMLObject object) { - try { - Marshaller marshaller = XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(object); - Element element = marshaller.marshall(object); - return SerializeSupport.nodeToString(element); - } - catch (MarshallingException ex) { - throw new Saml2Exception(ex); - } - } - - static O sign(O object, RelyingPartyRegistration relyingPartyRegistration) { - SignatureSigningParameters parameters = resolveSigningParameters(relyingPartyRegistration); - try { - SignatureSupport.signObject(object, parameters); - return object; - } - catch (Exception ex) { - throw new Saml2Exception(ex); - } - } - - static QueryParametersPartial sign(RelyingPartyRegistration registration) { - return new QueryParametersPartial(registration); - } - - private static SignatureSigningParameters resolveSigningParameters( - RelyingPartyRegistration relyingPartyRegistration) { - List credentials = resolveSigningCredentials(relyingPartyRegistration); - List algorithms = relyingPartyRegistration.getAssertingPartyMetadata().getSigningAlgorithms(); - List digests = Collections.singletonList(SignatureConstants.ALGO_ID_DIGEST_SHA256); - String canonicalization = SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS; - SignatureSigningParametersResolver resolver = new SAMLMetadataSignatureSigningParametersResolver(); - CriteriaSet criteria = new CriteriaSet(); - BasicSignatureSigningConfiguration signingConfiguration = new BasicSignatureSigningConfiguration(); - signingConfiguration.setSigningCredentials(credentials); - signingConfiguration.setSignatureAlgorithms(algorithms); - signingConfiguration.setSignatureReferenceDigestMethods(digests); - signingConfiguration.setSignatureCanonicalizationAlgorithm(canonicalization); - signingConfiguration.setKeyInfoGeneratorManager(buildSignatureKeyInfoGeneratorManager()); - criteria.add(new SignatureSigningConfigurationCriterion(signingConfiguration)); - try { - SignatureSigningParameters parameters = resolver.resolveSingle(criteria); - Assert.notNull(parameters, "Failed to resolve any signing credential"); - return parameters; - } - catch (Exception ex) { - throw new Saml2Exception(ex); - } - } - - private static NamedKeyInfoGeneratorManager buildSignatureKeyInfoGeneratorManager() { - final NamedKeyInfoGeneratorManager namedManager = new NamedKeyInfoGeneratorManager(); - - namedManager.setUseDefaultManager(true); - final KeyInfoGeneratorManager defaultManager = namedManager.getDefaultManager(); - - // Generator for X509Credentials - final X509KeyInfoGeneratorFactory x509Factory = new X509KeyInfoGeneratorFactory(); - x509Factory.setEmitEntityCertificate(true); - x509Factory.setEmitEntityCertificateChain(true); - - defaultManager.registerFactory(x509Factory); - - return namedManager; - } - - private static List resolveSigningCredentials(RelyingPartyRegistration relyingPartyRegistration) { - List credentials = new ArrayList<>(); - for (Saml2X509Credential x509Credential : relyingPartyRegistration.getSigningX509Credentials()) { - X509Certificate certificate = x509Credential.getCertificate(); - PrivateKey privateKey = x509Credential.getPrivateKey(); - BasicCredential credential = CredentialSupport.getSimpleCredential(certificate, privateKey); - credential.setEntityId(relyingPartyRegistration.getEntityId()); - credential.setUsageType(UsageType.SIGNING); - credentials.add(credential); - } - return credentials; - } - - private OpenSamlSigningUtils() { - - } - - static class QueryParametersPartial { - - final RelyingPartyRegistration registration; - - final Map components = new LinkedHashMap<>(); - - QueryParametersPartial(RelyingPartyRegistration registration) { - this.registration = registration; - } - - QueryParametersPartial param(String key, String value) { - this.components.put(key, value); - return this; - } - - Map parameters() { - SignatureSigningParameters parameters = resolveSigningParameters(this.registration); - Credential credential = parameters.getSigningCredential(); - String algorithmUri = parameters.getSignatureAlgorithm(); - this.components.put("SigAlg", algorithmUri); - UriComponentsBuilder builder = UriComponentsBuilder.newInstance(); - for (Map.Entry component : this.components.entrySet()) { - builder.queryParam(component.getKey(), - UriUtils.encode(component.getValue(), StandardCharsets.ISO_8859_1)); - } - String queryString = builder.build(true).toString().substring(1); - try { - byte[] rawSignature = XMLSigningUtil.signWithURI(credential, algorithmUri, - queryString.getBytes(StandardCharsets.UTF_8)); - String b64Signature = Saml2Utils.samlEncode(rawSignature); - this.components.put("Signature", b64Signature); - } - catch (SecurityException ex) { - throw new Saml2Exception(ex); - } - return this.components; - } - - } - -} diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlVerificationUtils.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlVerificationUtils.java deleted file mode 100644 index 7d52867e4d..0000000000 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlVerificationUtils.java +++ /dev/null @@ -1,206 +0,0 @@ -/* - * Copyright 2002-2022 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.saml2.provider.service.web.authentication; - -import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashSet; -import java.util.Set; - -import jakarta.servlet.http.HttpServletRequest; -import net.shibboleth.utilities.java.support.resolver.CriteriaSet; -import org.opensaml.core.criterion.EntityIdCriterion; -import org.opensaml.saml.common.xml.SAMLConstants; -import org.opensaml.saml.criterion.ProtocolCriterion; -import org.opensaml.saml.metadata.criteria.role.impl.EvaluableProtocolRoleDescriptorCriterion; -import org.opensaml.saml.saml2.core.Issuer; -import org.opensaml.saml.saml2.core.RequestAbstractType; -import org.opensaml.saml.saml2.core.StatusResponseType; -import org.opensaml.saml.security.impl.SAMLSignatureProfileValidator; -import org.opensaml.security.credential.Credential; -import org.opensaml.security.credential.CredentialResolver; -import org.opensaml.security.credential.UsageType; -import org.opensaml.security.credential.criteria.impl.EvaluableEntityIDCredentialCriterion; -import org.opensaml.security.credential.criteria.impl.EvaluableUsageCredentialCriterion; -import org.opensaml.security.credential.impl.CollectionCredentialResolver; -import org.opensaml.security.criteria.UsageCriterion; -import org.opensaml.security.x509.BasicX509Credential; -import org.opensaml.xmlsec.config.impl.DefaultSecurityConfigurationBootstrap; -import org.opensaml.xmlsec.signature.Signature; -import org.opensaml.xmlsec.signature.support.SignatureTrustEngine; -import org.opensaml.xmlsec.signature.support.impl.ExplicitKeySignatureTrustEngine; - -import org.springframework.security.saml2.core.Saml2Error; -import org.springframework.security.saml2.core.Saml2ErrorCodes; -import org.springframework.security.saml2.core.Saml2ResponseValidatorResult; -import org.springframework.security.saml2.core.Saml2X509Credential; -import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; -import org.springframework.web.util.UriUtils; - -/** - * Utility methods for verifying SAML component signatures with OpenSAML - * - * For internal use only. - * - * @author Josh Cummings - */ - -final class OpenSamlVerificationUtils { - - static VerifierPartial verifySignature(StatusResponseType object, RelyingPartyRegistration registration) { - return new VerifierPartial(object, registration); - } - - static VerifierPartial verifySignature(RequestAbstractType object, RelyingPartyRegistration registration) { - return new VerifierPartial(object, registration); - } - - private OpenSamlVerificationUtils() { - - } - - static class VerifierPartial { - - private final String id; - - private final CriteriaSet criteria; - - private final SignatureTrustEngine trustEngine; - - VerifierPartial(StatusResponseType object, RelyingPartyRegistration registration) { - this.id = object.getID(); - this.criteria = verificationCriteria(object.getIssuer()); - this.trustEngine = trustEngine(registration); - } - - VerifierPartial(RequestAbstractType object, RelyingPartyRegistration registration) { - this.id = object.getID(); - this.criteria = verificationCriteria(object.getIssuer()); - this.trustEngine = trustEngine(registration); - } - - Saml2ResponseValidatorResult redirect(HttpServletRequest request, String objectParameterName) { - RedirectSignature signature = new RedirectSignature(request, objectParameterName); - if (signature.getAlgorithm() == null) { - return Saml2ResponseValidatorResult.failure(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, - "Missing signature algorithm for object [" + this.id + "]")); - } - if (!signature.hasSignature()) { - return Saml2ResponseValidatorResult.failure(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, - "Missing signature for object [" + this.id + "]")); - } - Collection errors = new ArrayList<>(); - String algorithmUri = signature.getAlgorithm(); - try { - if (!this.trustEngine.validate(signature.getSignature(), signature.getContent(), algorithmUri, - this.criteria, null)) { - errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, - "Invalid signature for object [" + this.id + "]")); - } - } - catch (Exception ex) { - errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, - "Invalid signature for object [" + this.id + "]: ")); - } - return Saml2ResponseValidatorResult.failure(errors); - } - - Saml2ResponseValidatorResult post(Signature signature) { - Collection errors = new ArrayList<>(); - SAMLSignatureProfileValidator profileValidator = new SAMLSignatureProfileValidator(); - try { - profileValidator.validate(signature); - } - catch (Exception ex) { - errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, - "Invalid signature for object [" + this.id + "]: ")); - } - - try { - if (!this.trustEngine.validate(signature, this.criteria)) { - errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, - "Invalid signature for object [" + this.id + "]")); - } - } - catch (Exception ex) { - errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, - "Invalid signature for object [" + this.id + "]: ")); - } - - return Saml2ResponseValidatorResult.failure(errors); - } - - private CriteriaSet verificationCriteria(Issuer issuer) { - CriteriaSet criteria = new CriteriaSet(); - criteria.add(new EvaluableEntityIDCredentialCriterion(new EntityIdCriterion(issuer.getValue()))); - criteria.add(new EvaluableProtocolRoleDescriptorCriterion(new ProtocolCriterion(SAMLConstants.SAML20P_NS))); - criteria.add(new EvaluableUsageCredentialCriterion(new UsageCriterion(UsageType.SIGNING))); - return criteria; - } - - private SignatureTrustEngine trustEngine(RelyingPartyRegistration registration) { - Set credentials = new HashSet<>(); - Collection keys = registration.getAssertingPartyMetadata() - .getVerificationX509Credentials(); - for (Saml2X509Credential key : keys) { - BasicX509Credential cred = new BasicX509Credential(key.getCertificate()); - cred.setUsageType(UsageType.SIGNING); - cred.setEntityId(registration.getAssertingPartyMetadata().getEntityId()); - credentials.add(cred); - } - CredentialResolver credentialsResolver = new CollectionCredentialResolver(credentials); - return new ExplicitKeySignatureTrustEngine(credentialsResolver, - DefaultSecurityConfigurationBootstrap.buildBasicInlineKeyInfoCredentialResolver()); - } - - private static class RedirectSignature { - - private final HttpServletRequest request; - - private final String objectParameterName; - - RedirectSignature(HttpServletRequest request, String objectParameterName) { - this.request = request; - this.objectParameterName = objectParameterName; - } - - String getAlgorithm() { - return this.request.getParameter("SigAlg"); - } - - byte[] getContent() { - String query = String.format("%s=%s&SigAlg=%s", this.objectParameterName, - UriUtils.encode(this.request.getParameter(this.objectParameterName), - StandardCharsets.ISO_8859_1), - UriUtils.encode(getAlgorithm(), StandardCharsets.ISO_8859_1)); - return query.getBytes(StandardCharsets.UTF_8); - } - - byte[] getSignature() { - return Saml2Utils.samlDecode(this.request.getParameter("Signature")); - } - - boolean hasSignature() { - return this.request.getParameter("Signature") != null; - } - - } - - } - -} diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/Saml2Utils.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/Saml2Utils.java index 019fab46c9..17c4ffde4f 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/Saml2Utils.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/Saml2Utils.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ package org.springframework.security.saml2.provider.service.web.authentication; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.util.Arrays; import java.util.Base64; import java.util.zip.Deflater; import java.util.zip.DeflaterOutputStream; @@ -73,4 +74,123 @@ final class Saml2Utils { } } + static EncodingConfigurer withDecoded(String decoded) { + return new EncodingConfigurer(decoded); + } + + static DecodingConfigurer withEncoded(String encoded) { + return new DecodingConfigurer(encoded); + } + + static final class EncodingConfigurer { + + private final String decoded; + + private boolean deflate; + + private EncodingConfigurer(String decoded) { + this.decoded = decoded; + } + + EncodingConfigurer deflate(boolean deflate) { + this.deflate = deflate; + return this; + } + + String encode() { + byte[] bytes = (this.deflate) ? Saml2Utils.samlDeflate(this.decoded) + : this.decoded.getBytes(StandardCharsets.UTF_8); + return Saml2Utils.samlEncode(bytes); + } + + } + + static final class DecodingConfigurer { + + private static final Base64Checker BASE_64_CHECKER = new Base64Checker(); + + private final String encoded; + + private boolean inflate; + + private boolean requireBase64; + + private DecodingConfigurer(String encoded) { + this.encoded = encoded; + } + + DecodingConfigurer inflate(boolean inflate) { + this.inflate = inflate; + return this; + } + + DecodingConfigurer requireBase64(boolean requireBase64) { + this.requireBase64 = requireBase64; + return this; + } + + String decode() { + if (this.requireBase64) { + BASE_64_CHECKER.checkAcceptable(this.encoded); + } + byte[] bytes = Saml2Utils.samlDecode(this.encoded); + return (this.inflate) ? Saml2Utils.samlInflate(bytes) : new String(bytes, StandardCharsets.UTF_8); + } + + static class Base64Checker { + + private static final int[] values = genValueMapping(); + + Base64Checker() { + + } + + private static int[] genValueMapping() { + byte[] alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/" + .getBytes(StandardCharsets.ISO_8859_1); + + int[] values = new int[256]; + Arrays.fill(values, -1); + for (int i = 0; i < alphabet.length; i++) { + values[alphabet[i] & 0xff] = i; + } + return values; + } + + boolean isAcceptable(String s) { + int goodChars = 0; + int lastGoodCharVal = -1; + + // count number of characters from Base64 alphabet + for (int i = 0; i < s.length(); i++) { + int val = values[0xff & s.charAt(i)]; + if (val != -1) { + lastGoodCharVal = val; + goodChars++; + } + } + + // in cases of an incomplete final chunk, ensure the unused bits are zero + switch (goodChars % 4) { + case 0: + return true; + case 2: + return (lastGoodCharVal & 0b1111) == 0; + case 3: + return (lastGoodCharVal & 0b11) == 0; + default: + return false; + } + } + + void checkAcceptable(String ins) { + if (!isAcceptable(ins)) { + throw new IllegalArgumentException("Failed to decode SAMLResponse"); + } + } + + } + + } + } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlSigningUtilsTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml4SigningUtilsTests.java similarity index 93% rename from saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlSigningUtilsTests.java rename to saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml4SigningUtilsTests.java index 29ceaa02f5..1ebaf31f27 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlSigningUtilsTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml4SigningUtilsTests.java @@ -38,12 +38,14 @@ import static org.assertj.core.api.Assertions.assertThat; /** * Test open SAML signatures */ -public class OpenSamlSigningUtilsTests { +public class OpenSaml4SigningUtilsTests { static { OpenSamlInitializationService.initialize(); } + private final OpenSamlOperations saml = new OpenSaml4Template(); + private RelyingPartyRegistration registration; @BeforeEach @@ -62,7 +64,7 @@ public class OpenSamlSigningUtilsTests { @Test public void whenSigningAnObjectThenKeyInfoIsPartOfTheSignature() { Response response = response("destination", "issuer"); - OpenSamlSigningUtils.sign(response, this.registration); + this.saml.withSigningKeys(this.registration.getSigningX509Credentials()).sign(response); Signature signature = response.getSignature(); assertThat(signature).isNotNull(); assertThat(signature.getKeyInfo()).isNotNull(); diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolverTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolverTests.java deleted file mode 100644 index b4088bac63..0000000000 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolverTests.java +++ /dev/null @@ -1,270 +0,0 @@ -/* - * Copyright 2002-2023 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.saml2.provider.service.web.authentication; - -import java.util.stream.Stream; - -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; -import org.mockito.Answers; -import org.mockito.MockedStatic; -import org.opensaml.xmlsec.signature.support.SignatureConstants; - -import org.springframework.mock.web.MockHttpServletRequest; -import org.springframework.security.saml2.Saml2Exception; -import org.springframework.security.saml2.core.Saml2ParameterNames; -import org.springframework.security.saml2.core.Saml2X509Credential; -import org.springframework.security.saml2.core.TestSaml2X509Credentials; -import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest; -import org.springframework.security.saml2.provider.service.authentication.Saml2RedirectAuthenticationRequest; -import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; -import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; -import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; -import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers; -import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers.UriResolver; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.mockStatic; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.verify; - -/** - * Tests for {@link OpenSamlAuthenticationRequestResolver} - */ -public class OpenSamlAuthenticationRequestResolverTests { - - private RelyingPartyRegistration.Builder relyingPartyRegistrationBuilder; - - @BeforeEach - public void setUp() { - this.relyingPartyRegistrationBuilder = TestRelyingPartyRegistrations.relyingPartyRegistration(); - } - - @ParameterizedTest - @MethodSource("provideSignRequestFlags") - public void resolveAuthenticationRequestWhenSignedRedirectThenSignsAndRedirects(boolean wantAuthRequestsSigned, - boolean authnRequestsSigned) { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setPathInfo("/saml2/authenticate/registration-id"); - RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder - .authnRequestsSigned(authnRequestsSigned) - .assertingPartyDetails((party) -> party.wantAuthnRequestsSigned(wantAuthRequestsSigned)) - .build(); - OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration); - Saml2RedirectAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> { - UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration); - assertThat(authnRequest.getNameIDPolicy().getFormat()).isEqualTo(registration.getNameIdFormat()); - assertThat(authnRequest.getAssertionConsumerServiceURL()) - .isEqualTo(uriResolver.resolve(registration.getAssertionConsumerServiceLocation())); - assertThat(authnRequest.getProtocolBinding()) - .isEqualTo(registration.getAssertionConsumerServiceBinding().getUrn()); - assertThat(authnRequest.getDestination()) - .isEqualTo(registration.getAssertingPartyDetails().getSingleSignOnServiceLocation()); - assertThat(authnRequest.getIssuer().getValue()).isEqualTo(uriResolver.resolve(registration.getEntityId())); - }); - assertThat(result.getSamlRequest()).isNotEmpty(); - assertThat(result.getRelayState()).isNotNull(); - assertThat(result.getSigAlg()).isEqualTo(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256); - assertThat(result.getSignature()).isNotEmpty(); - assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT); - assertThat(result.getId()).isNotEmpty(); - } - - @Test - public void resolveAuthenticationRequestWhenUnsignedRedirectThenRedirectsAndNoSignature() { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setPathInfo("/saml2/authenticate/registration-id"); - RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder - .assertingPartyDetails((party) -> party.wantAuthnRequestsSigned(false)) - .build(); - OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration); - Saml2RedirectAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> { - UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration); - assertThat(authnRequest.getNameIDPolicy().getFormat()).isEqualTo(registration.getNameIdFormat()); - assertThat(authnRequest.getAssertionConsumerServiceURL()) - .isEqualTo(uriResolver.resolve(registration.getAssertionConsumerServiceLocation())); - assertThat(authnRequest.getProtocolBinding()) - .isEqualTo(registration.getAssertionConsumerServiceBinding().getUrn()); - assertThat(authnRequest.getDestination()) - .isEqualTo(registration.getAssertingPartyDetails().getSingleSignOnServiceLocation()); - assertThat(authnRequest.getIssuer().getValue()).isEqualTo(uriResolver.resolve(registration.getEntityId())); - }); - assertThat(result.getSamlRequest()).isNotEmpty(); - assertThat(result.getRelayState()).isNotNull(); - assertThat(result.getSigAlg()).isNull(); - assertThat(result.getSignature()).isNull(); - assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT); - assertThat(result.getId()).isNotEmpty(); - } - - @Test - public void resolveAuthenticationRequestWhenSignedThenCredentialIsRequired() { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setPathInfo("/saml2/authenticate/registration-id"); - Saml2X509Credential credential = TestSaml2X509Credentials.relyingPartyVerifyingCredential(); - RelyingPartyRegistration registration = TestRelyingPartyRegistrations.noCredentials() - .assertingPartyDetails((party) -> party.verificationX509Credentials((c) -> c.add(credential))) - .build(); - OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration); - assertThatExceptionOfType(Saml2Exception.class) - .isThrownBy(() -> resolver.resolve(request, (r, authnRequest) -> { - })); - } - - @Test - public void resolveAuthenticationRequestWhenUnsignedPostThenOnlyPosts() { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setPathInfo("/saml2/authenticate/registration-id"); - RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder.assertingPartyDetails( - (party) -> party.singleSignOnServiceBinding(Saml2MessageBinding.POST).wantAuthnRequestsSigned(false)) - .authnRequestsSigned(false) - .build(); - OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration); - Saml2PostAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> { - UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration); - assertThat(authnRequest.getNameIDPolicy().getFormat()).isEqualTo(registration.getNameIdFormat()); - assertThat(authnRequest.getAssertionConsumerServiceURL()) - .isEqualTo(uriResolver.resolve(registration.getAssertionConsumerServiceLocation())); - assertThat(authnRequest.getProtocolBinding()) - .isEqualTo(registration.getAssertionConsumerServiceBinding().getUrn()); - assertThat(authnRequest.getDestination()) - .isEqualTo(registration.getAssertingPartyDetails().getSingleSignOnServiceLocation()); - assertThat(authnRequest.getIssuer().getValue()).isEqualTo(uriResolver.resolve(registration.getEntityId())); - }); - assertThat(result.getSamlRequest()).isNotEmpty(); - assertThat(result.getRelayState()).isNotNull(); - assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.POST); - assertThat(new String(Saml2Utils.samlDecode(result.getSamlRequest()))).doesNotContain("Signature"); - assertThat(result.getId()).isNotEmpty(); - } - - @ParameterizedTest - @MethodSource("provideSignRequestFlags") - public void resolveAuthenticationRequestWhenSignedPostThenSignsAndPosts(boolean wantAuthRequestsSigned, - boolean authnRequestsSigned) { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setPathInfo("/saml2/authenticate/registration-id"); - RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder - .authnRequestsSigned(authnRequestsSigned) - .assertingPartyDetails((party) -> party.singleSignOnServiceBinding(Saml2MessageBinding.POST) - .wantAuthnRequestsSigned(wantAuthRequestsSigned)) - .build(); - OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration); - Saml2PostAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> { - UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration); - assertThat(authnRequest.getNameIDPolicy().getFormat()).isEqualTo(registration.getNameIdFormat()); - assertThat(authnRequest.getAssertionConsumerServiceURL()) - .isEqualTo(uriResolver.resolve(registration.getAssertionConsumerServiceLocation())); - assertThat(authnRequest.getProtocolBinding()) - .isEqualTo(registration.getAssertionConsumerServiceBinding().getUrn()); - assertThat(authnRequest.getDestination()) - .isEqualTo(registration.getAssertingPartyDetails().getSingleSignOnServiceLocation()); - assertThat(authnRequest.getIssuer().getValue()).isEqualTo(uriResolver.resolve(registration.getEntityId())); - }); - assertThat(result.getSamlRequest()).isNotEmpty(); - assertThat(result.getRelayState()).isNotNull(); - assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.POST); - assertThat(new String(Saml2Utils.samlDecode(result.getSamlRequest()))).contains("Signature"); - assertThat(result.getId()).isNotEmpty(); - } - - @Test - public void resolveAuthenticationRequestWhenSHA1SignRequestThenSigns() { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setPathInfo("/saml2/authenticate/registration-id"); - RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder.assertingPartyDetails( - (party) -> party.signingAlgorithms((algs) -> algs.add(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA1))) - .build(); - OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration); - Saml2RedirectAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> { - }); - assertThat(result.getSamlRequest()).isNotEmpty(); - assertThat(result.getRelayState()).isNotNull(); - assertThat(result.getSigAlg()).isEqualTo(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA1); - assertThat(result.getSignature()).isNotNull(); - assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT); - assertThat(result.getId()).isNotEmpty(); - } - - @Test - public void resolveAuthenticationRequestWhenSignedAndRelayStateIsNullThenSignsWithoutRelayState() { - try (MockedStatic openSamlSigningUtilsMockedStatic = mockStatic( - OpenSamlSigningUtils.class, Answers.CALLS_REAL_METHODS)) { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setPathInfo("/saml2/authenticate/registration-id"); - RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder - .assertingPartyDetails((party) -> party.wantAuthnRequestsSigned(true)) - .build(); - OpenSamlSigningUtils.QueryParametersPartial queryParametersPartialSpy = spy( - new OpenSamlSigningUtils.QueryParametersPartial(registration)); - openSamlSigningUtilsMockedStatic.when(() -> OpenSamlSigningUtils.sign(any())) - .thenReturn(queryParametersPartialSpy); - OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration); - resolver.setRelayStateResolver((source) -> null); - Saml2RedirectAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> { - }); - assertThat(result.getSamlRequest()).isNotEmpty(); - assertThat(result.getRelayState()).isNull(); - assertThat(result.getSigAlg()).isNotNull(); - assertThat(result.getSignature()).isNotNull(); - assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT); - verify(queryParametersPartialSpy, never()).param(eq(Saml2ParameterNames.RELAY_STATE), any()); - } - } - - @Test - public void resolveAuthenticationRequestWhenSignedAndRelayStateIsEmptyThenSignsWithEmptyRelayState() { - try (MockedStatic openSamlSigningUtilsMockedStatic = mockStatic( - OpenSamlSigningUtils.class, Answers.CALLS_REAL_METHODS)) { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setPathInfo("/saml2/authenticate/registration-id"); - RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder - .assertingPartyDetails((party) -> party.wantAuthnRequestsSigned(true)) - .build(); - OpenSamlSigningUtils.QueryParametersPartial queryParametersPartialSpy = spy( - new OpenSamlSigningUtils.QueryParametersPartial(registration)); - openSamlSigningUtilsMockedStatic.when(() -> OpenSamlSigningUtils.sign(any())) - .thenReturn(queryParametersPartialSpy); - OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration); - resolver.setRelayStateResolver((source) -> ""); - Saml2RedirectAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> { - }); - assertThat(result.getSamlRequest()).isNotEmpty(); - assertThat(result.getRelayState()).isEmpty(); - assertThat(result.getSigAlg()).isNotNull(); - assertThat(result.getSignature()).isNotNull(); - assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT); - verify(queryParametersPartialSpy).param(eq(Saml2ParameterNames.RELAY_STATE), eq("")); - } - } - - private OpenSamlAuthenticationRequestResolver authenticationRequestResolver(RelyingPartyRegistration registration) { - return new OpenSamlAuthenticationRequestResolver((request, id) -> registration); - } - - private static Stream provideSignRequestFlags() { - return Stream.of(Arguments.of(true, true), Arguments.of(true, false), Arguments.of(false, true)); - } - -}