diff --git a/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle b/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle index 884df70a51..fa5c9fd1f4 100644 --- a/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle +++ b/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle @@ -13,6 +13,12 @@ sourceSets.configureEach { set -> filter { line -> line.replaceAll(".saml2.internal", ".saml2.provider.service.authentication.logout") } with from } + + copy { + into "$projectDir/src/$set.name/java/org/springframework/security/saml2/provider/service/authentication" + filter { line -> line.replaceAll(".saml2.internal", ".saml2.provider.service.authentication") } + with from + } } dependencies { diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/BaseOpenSamlAuthenticationProvider.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/BaseOpenSamlAuthenticationProvider.java new file mode 100644 index 0000000000..fa6035f95c --- /dev/null +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/BaseOpenSamlAuthenticationProvider.java @@ -0,0 +1,678 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.authentication; + +import java.lang.reflect.Field; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Consumer; + +import javax.annotation.Nonnull; +import javax.xml.namespace.QName; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.opensaml.core.xml.XMLObject; +import org.opensaml.core.xml.schema.XSAny; +import org.opensaml.core.xml.schema.XSBoolean; +import org.opensaml.core.xml.schema.XSBooleanValue; +import org.opensaml.core.xml.schema.XSDateTime; +import org.opensaml.core.xml.schema.XSInteger; +import org.opensaml.core.xml.schema.XSString; +import org.opensaml.core.xml.schema.XSURI; +import org.opensaml.saml.common.assertion.AssertionValidationException; +import org.opensaml.saml.common.assertion.ValidationContext; +import org.opensaml.saml.common.assertion.ValidationResult; +import org.opensaml.saml.saml2.assertion.ConditionValidator; +import org.opensaml.saml.saml2.assertion.SAML20AssertionValidator; +import org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters; +import org.opensaml.saml.saml2.assertion.StatementValidator; +import org.opensaml.saml.saml2.assertion.SubjectConfirmationValidator; +import org.opensaml.saml.saml2.assertion.impl.AudienceRestrictionConditionValidator; +import org.opensaml.saml.saml2.assertion.impl.BearerSubjectConfirmationValidator; +import org.opensaml.saml.saml2.assertion.impl.DelegationRestrictionConditionValidator; +import org.opensaml.saml.saml2.assertion.impl.ProxyRestrictionConditionValidator; +import org.opensaml.saml.saml2.core.Assertion; +import org.opensaml.saml.saml2.core.Attribute; +import org.opensaml.saml.saml2.core.AttributeStatement; +import org.opensaml.saml.saml2.core.AuthnStatement; +import org.opensaml.saml.saml2.core.Condition; +import org.opensaml.saml.saml2.core.OneTimeUse; +import org.opensaml.saml.saml2.core.Response; +import org.opensaml.saml.saml2.core.StatusCode; +import org.opensaml.saml.saml2.core.SubjectConfirmation; +import org.opensaml.saml.saml2.core.SubjectConfirmationData; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.core.log.LogMessage; +import org.springframework.security.authentication.AbstractAuthenticationToken; +import org.springframework.security.authentication.AuthenticationProvider; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.AuthenticationException; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.saml2.Saml2Exception; +import org.springframework.security.saml2.core.OpenSamlInitializationService; +import org.springframework.security.saml2.core.Saml2Error; +import org.springframework.security.saml2.core.Saml2ErrorCodes; +import org.springframework.security.saml2.core.Saml2ResponseValidatorResult; +import org.springframework.security.saml2.core.Saml2X509Credential; +import org.springframework.security.saml2.provider.service.registration.AssertingPartyMetadata; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.ReflectionUtils; +import org.springframework.util.StringUtils; + +class BaseOpenSamlAuthenticationProvider implements AuthenticationProvider { + + static { + OpenSamlInitializationService.initialize(); + } + + private final Log logger = LogFactory.getLog(this.getClass()); + + private final OpenSamlOperations saml; + + private final Converter responseSignatureValidator = createDefaultResponseSignatureValidator(); + + private Consumer responseElementsDecrypter = createDefaultResponseElementsDecrypter(); + + private Converter responseValidator = createDefaultResponseValidator(); + + private final Converter assertionSignatureValidator = createDefaultAssertionSignatureValidator(); + + private Consumer assertionElementsDecrypter = createDefaultAssertionElementsDecrypter(); + + private Converter assertionValidator = createDefaultAssertionValidator(); + + private Converter responseAuthenticationConverter = createDefaultResponseAuthenticationConverter(); + + private static final Set includeChildStatusCodes = new HashSet<>( + Arrays.asList(StatusCode.REQUESTER, StatusCode.RESPONDER, StatusCode.VERSION_MISMATCH)); + + BaseOpenSamlAuthenticationProvider(OpenSamlOperations saml) { + this.saml = saml; + } + + void setResponseElementsDecrypter(Consumer responseElementsDecrypter) { + Assert.notNull(responseElementsDecrypter, "responseElementsDecrypter cannot be null"); + this.responseElementsDecrypter = responseElementsDecrypter; + } + + void setResponseValidator(Converter responseValidator) { + Assert.notNull(responseValidator, "responseValidator cannot be null"); + this.responseValidator = responseValidator; + } + + void setAssertionValidator(Converter assertionValidator) { + Assert.notNull(assertionValidator, "assertionValidator cannot be null"); + this.assertionValidator = assertionValidator; + } + + void setAssertionElementsDecrypter(Consumer assertionDecrypter) { + Assert.notNull(assertionDecrypter, "assertionDecrypter cannot be null"); + this.assertionElementsDecrypter = assertionDecrypter; + } + + void setResponseAuthenticationConverter( + Converter responseAuthenticationConverter) { + Assert.notNull(responseAuthenticationConverter, "responseAuthenticationConverter cannot be null"); + this.responseAuthenticationConverter = responseAuthenticationConverter; + } + + static Converter createDefaultResponseValidator() { + return (responseToken) -> { + Response response = responseToken.getResponse(); + Saml2AuthenticationToken token = responseToken.getToken(); + Saml2ResponseValidatorResult result = Saml2ResponseValidatorResult.success(); + List statusCodes = getStatusCodes(response); + if (!isSuccess(statusCodes)) { + for (String statusCode : statusCodes) { + String message = String.format("Invalid status [%s] for SAML response [%s]", statusCode, + response.getID()); + result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, message)); + } + } + + String inResponseTo = response.getInResponseTo(); + result = result.concat(validateInResponseTo(token.getAuthenticationRequest(), inResponseTo)); + + String issuer = response.getIssuer().getValue(); + String destination = response.getDestination(); + String location = token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation(); + if (StringUtils.hasText(destination) && !destination.equals(location)) { + String message = "Invalid destination [" + destination + "] for SAML response [" + response.getID() + + "]"; + result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_DESTINATION, message)); + } + String assertingPartyEntityId = token.getRelyingPartyRegistration() + .getAssertingPartyMetadata() + .getEntityId(); + if (!StringUtils.hasText(issuer) || !issuer.equals(assertingPartyEntityId)) { + String message = String.format("Invalid issuer [%s] for SAML response [%s]", issuer, response.getID()); + result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_ISSUER, message)); + } + if (response.getAssertions().isEmpty()) { + result = result.concat( + new Saml2Error(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA, "No assertions found in response.")); + } + return result; + }; + } + + private static List getStatusCodes(Response response) { + if (response.getStatus() == null) { + return List.of(StatusCode.SUCCESS); + } + if (response.getStatus().getStatusCode() == null) { + return List.of(StatusCode.SUCCESS); + } + StatusCode parentStatusCode = response.getStatus().getStatusCode(); + String parentStatusCodeValue = parentStatusCode.getValue(); + if (!includeChildStatusCodes.contains(parentStatusCodeValue)) { + return List.of(parentStatusCodeValue); + } + StatusCode childStatusCode = parentStatusCode.getStatusCode(); + if (childStatusCode == null) { + return List.of(parentStatusCodeValue); + } + String childStatusCodeValue = childStatusCode.getValue(); + if (childStatusCodeValue == null) { + return List.of(parentStatusCodeValue); + } + return List.of(parentStatusCodeValue, childStatusCodeValue); + } + + private static boolean isSuccess(List statusCodes) { + if (statusCodes.size() != 1) { + return false; + } + + String statusCode = statusCodes.get(0); + return StatusCode.SUCCESS.equals(statusCode); + } + + private static Saml2ResponseValidatorResult validateInResponseTo(AbstractSaml2AuthenticationRequest storedRequest, + String inResponseTo) { + if (!StringUtils.hasText(inResponseTo)) { + return Saml2ResponseValidatorResult.success(); + } + if (storedRequest == null) { + String message = "The response contained an InResponseTo attribute [" + inResponseTo + "]" + + " but no saved authentication request was found"; + return Saml2ResponseValidatorResult + .failure(new Saml2Error(Saml2ErrorCodes.INVALID_IN_RESPONSE_TO, message)); + } + if (!inResponseTo.equals(storedRequest.getId())) { + String message = "The InResponseTo attribute [" + inResponseTo + "] does not match the ID of the " + + "authentication request [" + storedRequest.getId() + "]"; + return Saml2ResponseValidatorResult + .failure(new Saml2Error(Saml2ErrorCodes.INVALID_IN_RESPONSE_TO, message)); + } + return Saml2ResponseValidatorResult.success(); + } + + static Converter createDefaultAssertionValidator() { + return createDefaultAssertionValidatorWithParameters( + (params) -> params.put(SAML2AssertionValidationParameters.CLOCK_SKEW, Duration.ofMinutes(5))); + } + + static Converter createDefaultAssertionValidator( + Converter contextConverter) { + + return createAssertionValidator(Saml2ErrorCodes.INVALID_ASSERTION, + (assertionToken) -> SAML20AssertionValidators.attributeValidator, contextConverter); + } + + static Converter createDefaultAssertionValidatorWithParameters( + Consumer> validationContextParameters) { + return createAssertionValidator(Saml2ErrorCodes.INVALID_ASSERTION, + (assertionToken) -> SAML20AssertionValidators.attributeValidator, + (assertionToken) -> createValidationContext(assertionToken, validationContextParameters)); + } + + static Converter createDefaultResponseAuthenticationConverter() { + return (responseToken) -> { + Response response = responseToken.response; + Saml2AuthenticationToken token = responseToken.token; + Assertion assertion = CollectionUtils.firstElement(response.getAssertions()); + String username = assertion.getSubject().getNameID().getValue(); + Map> attributes = getAssertionAttributes(assertion); + List sessionIndexes = getSessionIndexes(assertion); + DefaultSaml2AuthenticatedPrincipal principal = new DefaultSaml2AuthenticatedPrincipal(username, attributes, + sessionIndexes); + String registrationId = responseToken.token.getRelyingPartyRegistration().getRegistrationId(); + principal.setRelyingPartyRegistrationId(registrationId); + return new Saml2Authentication(principal, token.getSaml2Response(), + AuthorityUtils.createAuthorityList("ROLE_USER")); + }; + } + + @Override + public Authentication authenticate(Authentication authentication) throws AuthenticationException { + try { + Saml2AuthenticationToken token = (Saml2AuthenticationToken) authentication; + String serializedResponse = token.getSaml2Response(); + Response response = parseResponse(serializedResponse); + process(token, response); + AbstractAuthenticationToken authenticationResponse = this.responseAuthenticationConverter + .convert(new ResponseToken(response, token)); + if (authenticationResponse != null) { + authenticationResponse.setDetails(authentication.getDetails()); + } + return authenticationResponse; + } + catch (Saml2AuthenticationException ex) { + throw ex; + } + catch (Exception ex) { + throw createAuthenticationException(Saml2ErrorCodes.INTERNAL_VALIDATION_ERROR, ex.getMessage(), ex); + } + } + + @Override + public boolean supports(Class authentication) { + return authentication != null && Saml2AuthenticationToken.class.isAssignableFrom(authentication); + } + + private Response parseResponse(String response) throws Saml2Exception, Saml2AuthenticationException { + try { + return this.saml.deserialize(response); + } + catch (Exception ex) { + throw createAuthenticationException(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA, ex.getMessage(), ex); + } + } + + private void process(Saml2AuthenticationToken token, Response response) { + String issuer = response.getIssuer().getValue(); + this.logger.debug(LogMessage.format("Processing SAML response from %s", issuer)); + boolean responseSigned = response.isSigned(); + + ResponseToken responseToken = new ResponseToken(response, token); + Saml2ResponseValidatorResult result = this.responseSignatureValidator.convert(responseToken); + if (responseSigned) { + this.responseElementsDecrypter.accept(responseToken); + } + else if (!response.getEncryptedAssertions().isEmpty()) { + result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Did not decrypt response [" + response.getID() + "] since it is not signed")); + } + result = result.concat(this.responseValidator.convert(responseToken)); + boolean allAssertionsSigned = true; + for (Assertion assertion : response.getAssertions()) { + AssertionToken assertionToken = new AssertionToken(assertion, token); + result = result.concat(this.assertionSignatureValidator.convert(assertionToken)); + allAssertionsSigned = allAssertionsSigned && assertion.isSigned(); + if (responseSigned || assertion.isSigned()) { + this.assertionElementsDecrypter.accept(new AssertionToken(assertion, token)); + } + result = result.concat(this.assertionValidator.convert(assertionToken)); + } + if (!responseSigned && !allAssertionsSigned) { + String description = "Either the response or one of the assertions is unsigned. " + + "Please either sign the response or all of the assertions."; + result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, description)); + } + Assertion firstAssertion = CollectionUtils.firstElement(response.getAssertions()); + if (firstAssertion != null && !hasName(firstAssertion)) { + Saml2Error error = new Saml2Error(Saml2ErrorCodes.SUBJECT_NOT_FOUND, + "Assertion [" + firstAssertion.getID() + "] is missing a subject"); + result = result.concat(error); + } + + if (result.hasErrors()) { + Collection errors = result.getErrors(); + if (this.logger.isTraceEnabled()) { + this.logger.trace("Found " + errors.size() + " validation errors in SAML response [" + response.getID() + + "]: " + errors); + } + else if (this.logger.isDebugEnabled()) { + this.logger + .debug("Found " + errors.size() + " validation errors in SAML response [" + response.getID() + "]"); + } + Saml2Error first = errors.iterator().next(); + throw createAuthenticationException(first.getErrorCode(), first.getDescription(), null); + } + else { + if (this.logger.isDebugEnabled()) { + this.logger.debug("Successfully processed SAML Response [" + response.getID() + "]"); + } + } + } + + private Converter createDefaultResponseSignatureValidator() { + return (responseToken) -> { + Response response = responseToken.getResponse(); + RelyingPartyRegistration registration = responseToken.getToken().getRelyingPartyRegistration(); + if (response.isSigned()) { + AssertingPartyMetadata details = registration.getAssertingPartyMetadata(); + Collection credentials = details.getVerificationX509Credentials(); + Collection errors = this.saml.withVerificationKeys(credentials) + .entityId(details.getEntityId()) + .verify(response); + return Saml2ResponseValidatorResult.failure(errors); + } + return Saml2ResponseValidatorResult.success(); + }; + } + + private Consumer createDefaultResponseElementsDecrypter() { + return (responseToken) -> { + Response response = responseToken.getResponse(); + RelyingPartyRegistration registration = responseToken.getToken().getRelyingPartyRegistration(); + try { + this.saml.withDecryptionKeys(registration.getDecryptionX509Credentials()).decrypt(response); + } + catch (Exception ex) { + throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex); + } + }; + } + + private Converter createDefaultAssertionSignatureValidator() { + return (assertionToken) -> { + RelyingPartyRegistration registration = assertionToken.getToken().getRelyingPartyRegistration(); + Assertion assertion = assertionToken.getAssertion(); + if (assertion.isSigned()) { + AssertingPartyMetadata details = registration.getAssertingPartyMetadata(); + Collection credentials = details.getVerificationX509Credentials(); + Collection errors = this.saml.withVerificationKeys(credentials) + .entityId(details.getEntityId()) + .verify(assertion); + return Saml2ResponseValidatorResult.failure(errors); + } + return Saml2ResponseValidatorResult.success(); + }; + } + + private Consumer createDefaultAssertionElementsDecrypter() { + return (assertionToken) -> { + Assertion assertion = assertionToken.getAssertion(); + RelyingPartyRegistration registration = assertionToken.getToken().getRelyingPartyRegistration(); + try { + this.saml.withDecryptionKeys(registration.getDecryptionX509Credentials()).decrypt(assertion); + } + catch (Exception ex) { + throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex); + } + }; + } + + private boolean hasName(Assertion assertion) { + if (assertion == null) { + return false; + } + if (assertion.getSubject() == null) { + return false; + } + if (assertion.getSubject().getNameID() == null) { + return false; + } + return assertion.getSubject().getNameID().getValue() != null; + } + + private static Map> getAssertionAttributes(Assertion assertion) { + MultiValueMap attributeMap = new LinkedMultiValueMap<>(); + for (AttributeStatement attributeStatement : assertion.getAttributeStatements()) { + for (Attribute attribute : attributeStatement.getAttributes()) { + List attributeValues = new ArrayList<>(); + for (XMLObject xmlObject : attribute.getAttributeValues()) { + Object attributeValue = getXmlObjectValue(xmlObject); + if (attributeValue != null) { + attributeValues.add(attributeValue); + } + } + attributeMap.addAll(attribute.getName(), attributeValues); + } + } + return new LinkedHashMap<>(attributeMap); // gh-11785 + } + + private static List getSessionIndexes(Assertion assertion) { + List sessionIndexes = new ArrayList<>(); + for (AuthnStatement statement : assertion.getAuthnStatements()) { + sessionIndexes.add(statement.getSessionIndex()); + } + return sessionIndexes; + } + + private static Object getXmlObjectValue(XMLObject xmlObject) { + if (xmlObject instanceof XSAny) { + return ((XSAny) xmlObject).getTextContent(); + } + if (xmlObject instanceof XSString) { + return ((XSString) xmlObject).getValue(); + } + if (xmlObject instanceof XSInteger) { + return ((XSInteger) xmlObject).getValue(); + } + if (xmlObject instanceof XSURI) { + return ((XSURI) xmlObject).getURI(); + } + if (xmlObject instanceof XSBoolean) { + XSBooleanValue xsBooleanValue = ((XSBoolean) xmlObject).getValue(); + return (xsBooleanValue != null) ? xsBooleanValue.getValue() : null; + } + if (xmlObject instanceof XSDateTime) { + return ((XSDateTime) xmlObject).getValue(); + } + return xmlObject; + } + + private static Saml2AuthenticationException createAuthenticationException(String code, String message, + Exception cause) { + return new Saml2AuthenticationException(new Saml2Error(code, message), cause); + } + + private static Converter createAssertionValidator(String errorCode, + Converter validatorConverter, + Converter contextConverter) { + + return (assertionToken) -> { + Assertion assertion = assertionToken.assertion; + SAML20AssertionValidator validator = validatorConverter.convert(assertionToken); + ValidationContext context = contextConverter.convert(assertionToken); + try { + ValidationResult result = validator.validate(assertion, context); + if (result == ValidationResult.VALID) { + return Saml2ResponseValidatorResult.success(); + } + } + catch (Exception ex) { + String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s", assertion.getID(), + ((Response) assertion.getParent()).getID(), ex.getMessage()); + return Saml2ResponseValidatorResult.failure(new Saml2Error(errorCode, message)); + } + String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s", assertion.getID(), + ((Response) assertion.getParent()).getID(), contextToString(context)); + return Saml2ResponseValidatorResult.failure(new Saml2Error(errorCode, message)); + }; + } + + private static String contextToString(ValidationContext context) { + StringBuilder sb = new StringBuilder(); + for (Field field : context.getClass().getDeclaredFields()) { + ReflectionUtils.makeAccessible(field); + Object value = ReflectionUtils.getField(field, context); + sb.append(field.getName() + " = " + value + ","); + } + sb.deleteCharAt(sb.length() - 1); + return sb.toString(); + } + + private static ValidationContext createValidationContext(AssertionToken assertionToken, + Consumer> paramsConsumer) { + Saml2AuthenticationToken token = assertionToken.token; + RelyingPartyRegistration relyingPartyRegistration = token.getRelyingPartyRegistration(); + String audience = relyingPartyRegistration.getEntityId(); + String recipient = relyingPartyRegistration.getAssertionConsumerServiceLocation(); + String assertingPartyEntityId = relyingPartyRegistration.getAssertingPartyMetadata().getEntityId(); + Map params = new HashMap<>(); + Assertion assertion = assertionToken.getAssertion(); + if (assertionContainsInResponseTo(assertion)) { + String requestId = getAuthnRequestId(token.getAuthenticationRequest()); + params.put(SAML2AssertionValidationParameters.SC_VALID_IN_RESPONSE_TO, requestId); + } + params.put(SAML2AssertionValidationParameters.COND_VALID_AUDIENCES, Collections.singleton(audience)); + params.put(SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS, Collections.singleton(recipient)); + params.put(SAML2AssertionValidationParameters.VALID_ISSUERS, Collections.singleton(assertingPartyEntityId)); + paramsConsumer.accept(params); + return new ValidationContext(params); + } + + private static boolean assertionContainsInResponseTo(Assertion assertion) { + if (assertion.getSubject() == null) { + return false; + } + for (SubjectConfirmation confirmation : assertion.getSubject().getSubjectConfirmations()) { + SubjectConfirmationData confirmationData = confirmation.getSubjectConfirmationData(); + if (confirmationData == null) { + continue; + } + if (StringUtils.hasText(confirmationData.getInResponseTo())) { + return true; + } + } + return false; + } + + private static String getAuthnRequestId(AbstractSaml2AuthenticationRequest serialized) { + return (serialized != null) ? serialized.getId() : null; + } + + private static class SAML20AssertionValidators { + + private static final Collection conditions = new ArrayList<>(); + + private static final Collection subjects = new ArrayList<>(); + + private static final Collection statements = new ArrayList<>(); + + static { + conditions.add(new AudienceRestrictionConditionValidator()); + conditions.add(new DelegationRestrictionConditionValidator()); + conditions.add(new ConditionValidator() { + @Nonnull + @Override + public QName getServicedCondition() { + return OneTimeUse.DEFAULT_ELEMENT_NAME; + } + + @Nonnull + @Override + public ValidationResult validate(Condition condition, Assertion assertion, ValidationContext context) { + // applications should validate their own OneTimeUse conditions + return ValidationResult.VALID; + } + }); + conditions.add(new ProxyRestrictionConditionValidator()); + subjects.add(new BearerSubjectConfirmationValidator() { + @Nonnull + protected ValidationResult validateAddress(@Nonnull SubjectConfirmation confirmation, + @Nonnull Assertion assertion, @Nonnull ValidationContext context, boolean required) + throws AssertionValidationException { + return ValidationResult.VALID; + } + + @Nonnull + protected ValidationResult validateAddress(@Nonnull SubjectConfirmationData confirmationData, + @Nonnull Assertion assertion, @Nonnull ValidationContext context, boolean required) + throws AssertionValidationException { + // applications should validate their own addresses - gh-7514 + return ValidationResult.VALID; + } + }); + } + + private static final SAML20AssertionValidator attributeValidator = new SAML20AssertionValidator(conditions, + subjects, statements, null, null, null) { + @Nonnull + @Override + protected ValidationResult validateSignature(Assertion token, ValidationContext context) { + return ValidationResult.VALID; + } + }; + + } + + /** + * A tuple containing an OpenSAML {@link Response} and its associated authentication + * token. + * + * @since 5.4 + */ + static class ResponseToken { + + private final Saml2AuthenticationToken token; + + private final Response response; + + ResponseToken(Response response, Saml2AuthenticationToken token) { + this.token = token; + this.response = response; + } + + Response getResponse() { + return this.response; + } + + Saml2AuthenticationToken getToken() { + return this.token; + } + + } + + /** + * A tuple containing an OpenSAML {@link Assertion} and its associated authentication + * token. + * + * @since 5.4 + */ + static class AssertionToken { + + private final Saml2AuthenticationToken token; + + private final Assertion assertion; + + AssertionToken(Assertion assertion, Saml2AuthenticationToken token) { + this.token = token; + this.assertion = assertion; + } + + Assertion getAssertion() { + return this.assertion; + } + + Saml2AuthenticationToken getToken() { + return this.token; + } + + } + +} diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java index a2a390b1bb..b0eede6742 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java @@ -16,88 +16,24 @@ package org.springframework.security.saml2.provider.service.authentication; -import java.io.ByteArrayInputStream; -import java.nio.charset.StandardCharsets; import java.time.Duration; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.LinkedHashMap; -import java.util.List; import java.util.Map; -import java.util.Set; import java.util.function.Consumer; -import javax.annotation.Nonnull; -import javax.xml.namespace.QName; - -import net.shibboleth.utilities.java.support.xml.ParserPool; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; -import org.opensaml.core.config.ConfigurationService; -import org.opensaml.core.xml.XMLObject; -import org.opensaml.core.xml.config.XMLObjectProviderRegistry; -import org.opensaml.core.xml.schema.XSAny; -import org.opensaml.core.xml.schema.XSBoolean; -import org.opensaml.core.xml.schema.XSBooleanValue; -import org.opensaml.core.xml.schema.XSDateTime; -import org.opensaml.core.xml.schema.XSInteger; -import org.opensaml.core.xml.schema.XSString; -import org.opensaml.core.xml.schema.XSURI; -import org.opensaml.saml.common.assertion.AssertionValidationException; import org.opensaml.saml.common.assertion.ValidationContext; -import org.opensaml.saml.common.assertion.ValidationResult; -import org.opensaml.saml.saml2.assertion.ConditionValidator; -import org.opensaml.saml.saml2.assertion.SAML20AssertionValidator; import org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters; -import org.opensaml.saml.saml2.assertion.StatementValidator; -import org.opensaml.saml.saml2.assertion.SubjectConfirmationValidator; -import org.opensaml.saml.saml2.assertion.impl.AudienceRestrictionConditionValidator; -import org.opensaml.saml.saml2.assertion.impl.BearerSubjectConfirmationValidator; -import org.opensaml.saml.saml2.assertion.impl.DelegationRestrictionConditionValidator; -import org.opensaml.saml.saml2.assertion.impl.ProxyRestrictionConditionValidator; import org.opensaml.saml.saml2.core.Assertion; -import org.opensaml.saml.saml2.core.Attribute; -import org.opensaml.saml.saml2.core.AttributeStatement; -import org.opensaml.saml.saml2.core.AuthnRequest; -import org.opensaml.saml.saml2.core.AuthnStatement; -import org.opensaml.saml.saml2.core.Condition; import org.opensaml.saml.saml2.core.EncryptedAssertion; -import org.opensaml.saml.saml2.core.OneTimeUse; import org.opensaml.saml.saml2.core.Response; -import org.opensaml.saml.saml2.core.StatusCode; -import org.opensaml.saml.saml2.core.SubjectConfirmation; -import org.opensaml.saml.saml2.core.SubjectConfirmationData; -import org.opensaml.saml.saml2.core.impl.AuthnRequestUnmarshaller; -import org.opensaml.saml.saml2.core.impl.ResponseUnmarshaller; import org.opensaml.saml.saml2.encryption.Decrypter; -import org.opensaml.saml.security.impl.SAMLSignatureProfileValidator; -import org.opensaml.xmlsec.signature.support.SignaturePrevalidator; -import org.opensaml.xmlsec.signature.support.SignatureTrustEngine; -import org.w3c.dom.Document; -import org.w3c.dom.Element; import org.springframework.core.convert.converter.Converter; -import org.springframework.core.log.LogMessage; import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; -import org.springframework.security.core.authority.AuthorityUtils; -import org.springframework.security.saml2.Saml2Exception; -import org.springframework.security.saml2.core.OpenSamlInitializationService; -import org.springframework.security.saml2.core.Saml2Error; -import org.springframework.security.saml2.core.Saml2ErrorCodes; import org.springframework.security.saml2.core.Saml2ResponseValidatorResult; -import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.util.Assert; -import org.springframework.util.CollectionUtils; -import org.springframework.util.LinkedMultiValueMap; -import org.springframework.util.MultiValueMap; -import org.springframework.util.StringUtils; /** * Implementation of {@link AuthenticationProvider} for SAML authentications when @@ -142,48 +78,13 @@ import org.springframework.util.StringUtils; */ public final class OpenSaml4AuthenticationProvider implements AuthenticationProvider { - static { - OpenSamlInitializationService.initialize(); - } - - private final Log logger = LogFactory.getLog(this.getClass()); - - private final ResponseUnmarshaller responseUnmarshaller; - - private static final AuthnRequestUnmarshaller authnRequestUnmarshaller; - static { - XMLObjectProviderRegistry registry = ConfigurationService.get(XMLObjectProviderRegistry.class); - authnRequestUnmarshaller = (AuthnRequestUnmarshaller) registry.getUnmarshallerFactory() - .getUnmarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME); - } - - private final ParserPool parserPool; - - private final Converter responseSignatureValidator = createDefaultResponseSignatureValidator(); - - private Consumer responseElementsDecrypter = createDefaultResponseElementsDecrypter(); - - private Converter responseValidator = createDefaultResponseValidator(); - - private final Converter assertionSignatureValidator = createDefaultAssertionSignatureValidator(); - - private Consumer assertionElementsDecrypter = createDefaultAssertionElementsDecrypter(); - - private Converter assertionValidator = createDefaultAssertionValidator(); - - private Converter responseAuthenticationConverter = createDefaultResponseAuthenticationConverter(); - - private static final Set includeChildStatusCodes = new HashSet<>( - Arrays.asList(StatusCode.REQUESTER, StatusCode.RESPONDER, StatusCode.VERSION_MISMATCH)); + private final BaseOpenSamlAuthenticationProvider delegate; /** * Creates an {@link OpenSaml4AuthenticationProvider} */ public OpenSaml4AuthenticationProvider() { - XMLObjectProviderRegistry registry = ConfigurationService.get(XMLObjectProviderRegistry.class); - this.responseUnmarshaller = (ResponseUnmarshaller) registry.getUnmarshallerFactory() - .getUnmarshaller(Response.DEFAULT_ELEMENT_NAME); - this.parserPool = registry.getParserPool(); + this.delegate = new BaseOpenSamlAuthenticationProvider(new OpenSaml4Template()); } /** @@ -231,7 +132,8 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv */ public void setResponseElementsDecrypter(Consumer responseElementsDecrypter) { Assert.notNull(responseElementsDecrypter, "responseElementsDecrypter cannot be null"); - this.responseElementsDecrypter = responseElementsDecrypter; + this.delegate + .setResponseElementsDecrypter((token) -> responseElementsDecrypter.accept(new ResponseToken(token))); } /** @@ -253,7 +155,7 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv */ public void setResponseValidator(Converter responseValidator) { Assert.notNull(responseValidator, "responseValidator cannot be null"); - this.responseValidator = responseValidator; + this.delegate.setResponseValidator((token) -> responseValidator.convert(new ResponseToken(token))); } /** @@ -297,7 +199,7 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv */ public void setAssertionValidator(Converter assertionValidator) { Assert.notNull(assertionValidator, "assertionValidator cannot be null"); - this.assertionValidator = assertionValidator; + this.delegate.setAssertionValidator((token) -> assertionValidator.convert(new AssertionToken(token))); } /** @@ -340,7 +242,7 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv */ public void setAssertionElementsDecrypter(Consumer assertionDecrypter) { Assert.notNull(assertionDecrypter, "assertionDecrypter cannot be null"); - this.assertionElementsDecrypter = assertionDecrypter; + this.delegate.setAssertionElementsDecrypter((token) -> assertionDecrypter.accept(new AssertionToken(token))); } /** @@ -366,7 +268,8 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv public void setResponseAuthenticationConverter( Converter responseAuthenticationConverter) { Assert.notNull(responseAuthenticationConverter, "responseAuthenticationConverter cannot be null"); - this.responseAuthenticationConverter = responseAuthenticationConverter; + this.delegate.setResponseAuthenticationConverter( + (token) -> responseAuthenticationConverter.convert(new ResponseToken(token))); } /** @@ -375,95 +278,10 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv * @since 5.6 */ public static Converter createDefaultResponseValidator() { - return (responseToken) -> { - Response response = responseToken.getResponse(); - Saml2AuthenticationToken token = responseToken.getToken(); - Saml2ResponseValidatorResult result = Saml2ResponseValidatorResult.success(); - List statusCodes = getStatusCodes(response); - if (!isSuccess(statusCodes)) { - for (String statusCode : statusCodes) { - String message = String.format("Invalid status [%s] for SAML response [%s]", statusCode, - response.getID()); - result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, message)); - } - } - - String inResponseTo = response.getInResponseTo(); - result = result.concat(validateInResponseTo(token.getAuthenticationRequest(), inResponseTo)); - - String issuer = response.getIssuer().getValue(); - String destination = response.getDestination(); - String location = token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation(); - if (StringUtils.hasText(destination) && !destination.equals(location)) { - String message = "Invalid destination [" + destination + "] for SAML response [" + response.getID() - + "]"; - result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_DESTINATION, message)); - } - String assertingPartyEntityId = token.getRelyingPartyRegistration() - .getAssertingPartyMetadata() - .getEntityId(); - if (!StringUtils.hasText(issuer) || !issuer.equals(assertingPartyEntityId)) { - String message = String.format("Invalid issuer [%s] for SAML response [%s]", issuer, response.getID()); - result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_ISSUER, message)); - } - if (response.getAssertions().isEmpty()) { - result = result.concat( - new Saml2Error(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA, "No assertions found in response.")); - } - return result; - }; - } - - private static List getStatusCodes(Response response) { - if (response.getStatus() == null) { - return List.of(StatusCode.SUCCESS); - } - if (response.getStatus().getStatusCode() == null) { - return List.of(StatusCode.SUCCESS); - } - StatusCode parentStatusCode = response.getStatus().getStatusCode(); - String parentStatusCodeValue = parentStatusCode.getValue(); - if (!includeChildStatusCodes.contains(parentStatusCodeValue)) { - return List.of(parentStatusCodeValue); - } - StatusCode childStatusCode = parentStatusCode.getStatusCode(); - if (childStatusCode == null) { - return List.of(parentStatusCodeValue); - } - String childStatusCodeValue = childStatusCode.getValue(); - if (childStatusCodeValue == null) { - return List.of(parentStatusCodeValue); - } - return List.of(parentStatusCodeValue, childStatusCodeValue); - } - - private static boolean isSuccess(List statusCodes) { - if (statusCodes.size() != 1) { - return false; - } - - String statusCode = statusCodes.get(0); - return StatusCode.SUCCESS.equals(statusCode); - } - - private static Saml2ResponseValidatorResult validateInResponseTo(AbstractSaml2AuthenticationRequest storedRequest, - String inResponseTo) { - if (!StringUtils.hasText(inResponseTo)) { - return Saml2ResponseValidatorResult.success(); - } - if (storedRequest == null) { - String message = "The response contained an InResponseTo attribute [" + inResponseTo + "]" - + " but no saved authentication request was found"; - return Saml2ResponseValidatorResult - .failure(new Saml2Error(Saml2ErrorCodes.INVALID_IN_RESPONSE_TO, message)); - } - if (!inResponseTo.equals(storedRequest.getId())) { - String message = "The InResponseTo attribute [" + inResponseTo + "] does not match the ID of the " - + "authentication request [" + storedRequest.getId() + "]"; - return Saml2ResponseValidatorResult - .failure(new Saml2Error(Saml2ErrorCodes.INVALID_IN_RESPONSE_TO, message)); - } - return Saml2ResponseValidatorResult.success(); + Converter delegate = BaseOpenSamlAuthenticationProvider + .createDefaultResponseValidator(); + return (token) -> delegate + .convert(new BaseOpenSamlAuthenticationProvider.ResponseToken(token.getResponse(), token.getToken())); } /** @@ -472,7 +290,6 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv * @return the default assertion validator strategy */ public static Converter createDefaultAssertionValidator() { - return createDefaultAssertionValidatorWithParameters( (params) -> params.put(SAML2AssertionValidationParameters.CLOCK_SKEW, Duration.ofMinutes(5))); } @@ -488,9 +305,12 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv @Deprecated public static Converter createDefaultAssertionValidator( Converter contextConverter) { - - return createAssertionValidator(Saml2ErrorCodes.INVALID_ASSERTION, - (assertionToken) -> SAML20AssertionValidators.attributeValidator, contextConverter); + Converter contextDelegate = ( + token) -> contextConverter.convert(new AssertionToken(token.getAssertion(), token.getToken())); + Converter delegate = BaseOpenSamlAuthenticationProvider + .createDefaultAssertionValidator(contextDelegate); + return (token) -> delegate + .convert(new BaseOpenSamlAuthenticationProvider.AssertionToken(token.getAssertion(), token.getToken())); } /** @@ -503,9 +323,10 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv */ public static Converter createDefaultAssertionValidatorWithParameters( Consumer> validationContextParameters) { - return createAssertionValidator(Saml2ErrorCodes.INVALID_ASSERTION, - (assertionToken) -> SAML20AssertionValidators.attributeValidator, - (assertionToken) -> createValidationContext(assertionToken, validationContextParameters)); + Converter delegate = BaseOpenSamlAuthenticationProvider + .createDefaultAssertionValidatorWithParameters(validationContextParameters); + return (token) -> delegate + .convert(new BaseOpenSamlAuthenticationProvider.AssertionToken(token.getAssertion(), token.getToken())); } /** @@ -514,20 +335,10 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv * @return the default response authentication converter strategy */ public static Converter createDefaultResponseAuthenticationConverter() { - return (responseToken) -> { - Response response = responseToken.response; - Saml2AuthenticationToken token = responseToken.token; - Assertion assertion = CollectionUtils.firstElement(response.getAssertions()); - String username = assertion.getSubject().getNameID().getValue(); - Map> attributes = getAssertionAttributes(assertion); - List sessionIndexes = getSessionIndexes(assertion); - DefaultSaml2AuthenticatedPrincipal principal = new DefaultSaml2AuthenticatedPrincipal(username, attributes, - sessionIndexes); - String registrationId = responseToken.token.getRelyingPartyRegistration().getRegistrationId(); - principal.setRelyingPartyRegistrationId(registrationId); - return new Saml2Authentication(principal, token.getSaml2Response(), - AuthorityUtils.createAuthorityList("ROLE_USER")); - }; + Converter delegate = BaseOpenSamlAuthenticationProvider + .createDefaultResponseAuthenticationConverter(); + return (token) -> delegate + .convert(new BaseOpenSamlAuthenticationProvider.ResponseToken(token.getResponse(), token.getToken())); } /** @@ -538,24 +349,7 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv */ @Override public Authentication authenticate(Authentication authentication) throws AuthenticationException { - try { - Saml2AuthenticationToken token = (Saml2AuthenticationToken) authentication; - String serializedResponse = token.getSaml2Response(); - Response response = parseResponse(serializedResponse); - process(token, response); - AbstractAuthenticationToken authenticationResponse = this.responseAuthenticationConverter - .convert(new ResponseToken(response, token)); - if (authenticationResponse != null) { - authenticationResponse.setDetails(authentication.getDetails()); - } - return authenticationResponse; - } - catch (Saml2AuthenticationException ex) { - throw ex; - } - catch (Exception ex) { - throw createAuthenticationException(Saml2ErrorCodes.INTERNAL_VALIDATION_ERROR, ex.getMessage(), ex); - } + return this.delegate.authenticate(authentication); } @Override @@ -563,337 +357,6 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv return authentication != null && Saml2AuthenticationToken.class.isAssignableFrom(authentication); } - private Response parseResponse(String response) throws Saml2Exception, Saml2AuthenticationException { - try { - Document document = this.parserPool - .parse(new ByteArrayInputStream(response.getBytes(StandardCharsets.UTF_8))); - Element element = document.getDocumentElement(); - return (Response) this.responseUnmarshaller.unmarshall(element); - } - catch (Exception ex) { - throw createAuthenticationException(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA, ex.getMessage(), ex); - } - } - - private void process(Saml2AuthenticationToken token, Response response) { - String issuer = response.getIssuer().getValue(); - this.logger.debug(LogMessage.format("Processing SAML response from %s", issuer)); - boolean responseSigned = response.isSigned(); - - ResponseToken responseToken = new ResponseToken(response, token); - Saml2ResponseValidatorResult result = this.responseSignatureValidator.convert(responseToken); - if (responseSigned) { - this.responseElementsDecrypter.accept(responseToken); - } - else if (!response.getEncryptedAssertions().isEmpty()) { - result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, - "Did not decrypt response [" + response.getID() + "] since it is not signed")); - } - result = result.concat(this.responseValidator.convert(responseToken)); - boolean allAssertionsSigned = true; - for (Assertion assertion : response.getAssertions()) { - AssertionToken assertionToken = new AssertionToken(assertion, token); - result = result.concat(this.assertionSignatureValidator.convert(assertionToken)); - allAssertionsSigned = allAssertionsSigned && assertion.isSigned(); - if (responseSigned || assertion.isSigned()) { - this.assertionElementsDecrypter.accept(new AssertionToken(assertion, token)); - } - result = result.concat(this.assertionValidator.convert(assertionToken)); - } - if (!responseSigned && !allAssertionsSigned) { - String description = "Either the response or one of the assertions is unsigned. " - + "Please either sign the response or all of the assertions."; - result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, description)); - } - Assertion firstAssertion = CollectionUtils.firstElement(response.getAssertions()); - if (firstAssertion != null && !hasName(firstAssertion)) { - Saml2Error error = new Saml2Error(Saml2ErrorCodes.SUBJECT_NOT_FOUND, - "Assertion [" + firstAssertion.getID() + "] is missing a subject"); - result = result.concat(error); - } - - if (result.hasErrors()) { - Collection errors = result.getErrors(); - if (this.logger.isTraceEnabled()) { - this.logger.trace("Found " + errors.size() + " validation errors in SAML response [" + response.getID() - + "]: " + errors); - } - else if (this.logger.isDebugEnabled()) { - this.logger - .debug("Found " + errors.size() + " validation errors in SAML response [" + response.getID() + "]"); - } - Saml2Error first = errors.iterator().next(); - throw createAuthenticationException(first.getErrorCode(), first.getDescription(), null); - } - else { - if (this.logger.isDebugEnabled()) { - this.logger.debug("Successfully processed SAML Response [" + response.getID() + "]"); - } - } - } - - private Converter createDefaultResponseSignatureValidator() { - return (responseToken) -> { - Response response = responseToken.getResponse(); - RelyingPartyRegistration registration = responseToken.getToken().getRelyingPartyRegistration(); - if (response.isSigned()) { - return OpenSamlVerificationUtils.verifySignature(response, registration).post(response.getSignature()); - } - return Saml2ResponseValidatorResult.success(); - }; - } - - private Consumer createDefaultResponseElementsDecrypter() { - return (responseToken) -> { - Response response = responseToken.getResponse(); - RelyingPartyRegistration registration = responseToken.getToken().getRelyingPartyRegistration(); - try { - OpenSamlDecryptionUtils.decryptResponseElements(response, registration); - } - catch (Exception ex) { - throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex); - } - }; - } - - private Converter createDefaultAssertionSignatureValidator() { - return createAssertionValidator(Saml2ErrorCodes.INVALID_SIGNATURE, (assertionToken) -> { - RelyingPartyRegistration registration = assertionToken.getToken().getRelyingPartyRegistration(); - SignatureTrustEngine engine = OpenSamlVerificationUtils.trustEngine(registration); - return SAML20AssertionValidators.createSignatureValidator(engine); - }, (assertionToken) -> new ValidationContext( - Collections.singletonMap(SAML2AssertionValidationParameters.SIGNATURE_REQUIRED, false))); - } - - private Consumer createDefaultAssertionElementsDecrypter() { - return (assertionToken) -> { - Assertion assertion = assertionToken.getAssertion(); - RelyingPartyRegistration registration = assertionToken.getToken().getRelyingPartyRegistration(); - try { - OpenSamlDecryptionUtils.decryptAssertionElements(assertion, registration); - } - catch (Exception ex) { - throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex); - } - }; - } - - private boolean hasName(Assertion assertion) { - if (assertion == null) { - return false; - } - if (assertion.getSubject() == null) { - return false; - } - if (assertion.getSubject().getNameID() == null) { - return false; - } - return assertion.getSubject().getNameID().getValue() != null; - } - - private static Map> getAssertionAttributes(Assertion assertion) { - MultiValueMap attributeMap = new LinkedMultiValueMap<>(); - for (AttributeStatement attributeStatement : assertion.getAttributeStatements()) { - for (Attribute attribute : attributeStatement.getAttributes()) { - List attributeValues = new ArrayList<>(); - for (XMLObject xmlObject : attribute.getAttributeValues()) { - Object attributeValue = getXmlObjectValue(xmlObject); - if (attributeValue != null) { - attributeValues.add(attributeValue); - } - } - attributeMap.addAll(attribute.getName(), attributeValues); - } - } - return new LinkedHashMap<>(attributeMap); // gh-11785 - } - - private static List getSessionIndexes(Assertion assertion) { - List sessionIndexes = new ArrayList<>(); - for (AuthnStatement statement : assertion.getAuthnStatements()) { - sessionIndexes.add(statement.getSessionIndex()); - } - return sessionIndexes; - } - - private static Object getXmlObjectValue(XMLObject xmlObject) { - if (xmlObject instanceof XSAny) { - return ((XSAny) xmlObject).getTextContent(); - } - if (xmlObject instanceof XSString) { - return ((XSString) xmlObject).getValue(); - } - if (xmlObject instanceof XSInteger) { - return ((XSInteger) xmlObject).getValue(); - } - if (xmlObject instanceof XSURI) { - return ((XSURI) xmlObject).getURI(); - } - if (xmlObject instanceof XSBoolean) { - XSBooleanValue xsBooleanValue = ((XSBoolean) xmlObject).getValue(); - return (xsBooleanValue != null) ? xsBooleanValue.getValue() : null; - } - if (xmlObject instanceof XSDateTime) { - return ((XSDateTime) xmlObject).getValue(); - } - return xmlObject; - } - - private static Saml2AuthenticationException createAuthenticationException(String code, String message, - Exception cause) { - return new Saml2AuthenticationException(new Saml2Error(code, message), cause); - } - - private static Converter createAssertionValidator(String errorCode, - Converter validatorConverter, - Converter contextConverter) { - - return (assertionToken) -> { - Assertion assertion = assertionToken.assertion; - SAML20AssertionValidator validator = validatorConverter.convert(assertionToken); - ValidationContext context = contextConverter.convert(assertionToken); - try { - ValidationResult result = validator.validate(assertion, context); - if (result == ValidationResult.VALID) { - return Saml2ResponseValidatorResult.success(); - } - } - catch (Exception ex) { - String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s", assertion.getID(), - ((Response) assertion.getParent()).getID(), ex.getMessage()); - return Saml2ResponseValidatorResult.failure(new Saml2Error(errorCode, message)); - } - String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s", assertion.getID(), - ((Response) assertion.getParent()).getID(), context.getValidationFailureMessage()); - return Saml2ResponseValidatorResult.failure(new Saml2Error(errorCode, message)); - }; - } - - private static ValidationContext createValidationContext(AssertionToken assertionToken, - Consumer> paramsConsumer) { - Saml2AuthenticationToken token = assertionToken.token; - RelyingPartyRegistration relyingPartyRegistration = token.getRelyingPartyRegistration(); - String audience = relyingPartyRegistration.getEntityId(); - String recipient = relyingPartyRegistration.getAssertionConsumerServiceLocation(); - String assertingPartyEntityId = relyingPartyRegistration.getAssertingPartyMetadata().getEntityId(); - Map params = new HashMap<>(); - Assertion assertion = assertionToken.getAssertion(); - if (assertionContainsInResponseTo(assertion)) { - String requestId = getAuthnRequestId(token.getAuthenticationRequest()); - params.put(SAML2AssertionValidationParameters.SC_VALID_IN_RESPONSE_TO, requestId); - } - params.put(SAML2AssertionValidationParameters.COND_VALID_AUDIENCES, Collections.singleton(audience)); - params.put(SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS, Collections.singleton(recipient)); - params.put(SAML2AssertionValidationParameters.VALID_ISSUERS, Collections.singleton(assertingPartyEntityId)); - paramsConsumer.accept(params); - return new ValidationContext(params); - } - - private static boolean assertionContainsInResponseTo(Assertion assertion) { - if (assertion.getSubject() == null) { - return false; - } - for (SubjectConfirmation confirmation : assertion.getSubject().getSubjectConfirmations()) { - SubjectConfirmationData confirmationData = confirmation.getSubjectConfirmationData(); - if (confirmationData == null) { - continue; - } - if (StringUtils.hasText(confirmationData.getInResponseTo())) { - return true; - } - } - return false; - } - - private static String getAuthnRequestId(AbstractSaml2AuthenticationRequest serialized) { - return (serialized != null) ? serialized.getId() : null; - } - - private static class SAML20AssertionValidators { - - private static final Collection conditions = new ArrayList<>(); - - private static final Collection subjects = new ArrayList<>(); - - private static final Collection statements = new ArrayList<>(); - - private static final SignaturePrevalidator validator = new SAMLSignatureProfileValidator(); - - static { - conditions.add(new AudienceRestrictionConditionValidator()); - conditions.add(new DelegationRestrictionConditionValidator()); - conditions.add(new ConditionValidator() { - @Nonnull - @Override - public QName getServicedCondition() { - return OneTimeUse.DEFAULT_ELEMENT_NAME; - } - - @Nonnull - @Override - public ValidationResult validate(Condition condition, Assertion assertion, ValidationContext context) { - // applications should validate their own OneTimeUse conditions - return ValidationResult.VALID; - } - }); - conditions.add(new ProxyRestrictionConditionValidator()); - subjects.add(new BearerSubjectConfirmationValidator() { - @Override - protected ValidationResult validateAddress(SubjectConfirmation confirmation, Assertion assertion, - ValidationContext context, boolean required) { - // applications should validate their own addresses - gh-7514 - return ValidationResult.VALID; - } - }); - } - - private static final SAML20AssertionValidator attributeValidator = new SAML20AssertionValidator(conditions, - subjects, statements, null, null, null) { - @Nonnull - @Override - protected ValidationResult validateSignature(Assertion token, ValidationContext context) { - return ValidationResult.VALID; - } - }; - - static SAML20AssertionValidator createSignatureValidator(SignatureTrustEngine engine) { - return new SAML20AssertionValidator(new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), null, engine, - validator) { - @Nonnull - @Override - protected ValidationResult validateBasicData(@Nonnull Assertion assertion, - @Nonnull ValidationContext context) throws AssertionValidationException { - return ValidationResult.VALID; - } - - @Nonnull - @Override - protected ValidationResult validateConditions(Assertion assertion, ValidationContext context) { - return ValidationResult.VALID; - } - - @Nonnull - @Override - protected ValidationResult validateSubjectConfirmation(Assertion assertion, ValidationContext context) { - return ValidationResult.VALID; - } - - @Nonnull - @Override - protected ValidationResult validateStatements(Assertion assertion, ValidationContext context) { - return ValidationResult.VALID; - } - - @Override - protected ValidationResult validateIssuer(Assertion assertion, ValidationContext context) { - return ValidationResult.VALID; - } - }; - - } - - } - /** * A tuple containing an OpenSAML {@link Response} and its associated authentication * token. @@ -911,6 +374,11 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv this.response = response; } + ResponseToken(BaseOpenSamlAuthenticationProvider.ResponseToken token) { + this.token = token.getToken(); + this.response = token.getResponse(); + } + public Response getResponse() { return this.response; } @@ -938,6 +406,11 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv this.assertion = assertion; } + AssertionToken(BaseOpenSamlAuthenticationProvider.AssertionToken token) { + this.token = token.getToken(); + this.assertion = token.getAssertion(); + } + public Assertion getAssertion() { return this.assertion; } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4Template.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4Template.java new file mode 100644 index 0000000000..f529b8b4a0 --- /dev/null +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4Template.java @@ -0,0 +1,617 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.authentication; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import javax.xml.namespace.QName; + +import net.shibboleth.utilities.java.support.resolver.CriteriaSet; +import net.shibboleth.utilities.java.support.xml.SerializeSupport; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.opensaml.core.criterion.EntityIdCriterion; +import org.opensaml.core.xml.XMLObject; +import org.opensaml.core.xml.XMLObjectBuilder; +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; +import org.opensaml.core.xml.io.Marshaller; +import org.opensaml.core.xml.io.MarshallingException; +import org.opensaml.core.xml.io.Unmarshaller; +import org.opensaml.core.xml.io.UnmarshallerFactory; +import org.opensaml.core.xml.util.XMLObjectSupport; +import org.opensaml.saml.common.xml.SAMLConstants; +import org.opensaml.saml.criterion.ProtocolCriterion; +import org.opensaml.saml.ext.saml2delrestrict.Delegate; +import org.opensaml.saml.ext.saml2delrestrict.DelegationRestrictionType; +import org.opensaml.saml.metadata.criteria.role.impl.EvaluableProtocolRoleDescriptorCriterion; +import org.opensaml.saml.saml2.core.Assertion; +import org.opensaml.saml.saml2.core.Attribute; +import org.opensaml.saml.saml2.core.AttributeStatement; +import org.opensaml.saml.saml2.core.Condition; +import org.opensaml.saml.saml2.core.EncryptedAssertion; +import org.opensaml.saml.saml2.core.EncryptedAttribute; +import org.opensaml.saml.saml2.core.Issuer; +import org.opensaml.saml.saml2.core.LogoutRequest; +import org.opensaml.saml.saml2.core.NameID; +import org.opensaml.saml.saml2.core.RequestAbstractType; +import org.opensaml.saml.saml2.core.Response; +import org.opensaml.saml.saml2.core.StatusResponseType; +import org.opensaml.saml.saml2.core.Subject; +import org.opensaml.saml.saml2.core.SubjectConfirmation; +import org.opensaml.saml.saml2.encryption.Decrypter; +import org.opensaml.saml.saml2.encryption.EncryptedElementTypeEncryptedKeyResolver; +import org.opensaml.saml.security.impl.SAMLMetadataSignatureSigningParametersResolver; +import org.opensaml.saml.security.impl.SAMLSignatureProfileValidator; +import org.opensaml.security.SecurityException; +import org.opensaml.security.credential.BasicCredential; +import org.opensaml.security.credential.Credential; +import org.opensaml.security.credential.CredentialResolver; +import org.opensaml.security.credential.CredentialSupport; +import org.opensaml.security.credential.UsageType; +import org.opensaml.security.credential.criteria.impl.EvaluableEntityIDCredentialCriterion; +import org.opensaml.security.credential.criteria.impl.EvaluableUsageCredentialCriterion; +import org.opensaml.security.credential.impl.CollectionCredentialResolver; +import org.opensaml.security.criteria.UsageCriterion; +import org.opensaml.security.x509.BasicX509Credential; +import org.opensaml.xmlsec.SignatureSigningParameters; +import org.opensaml.xmlsec.SignatureSigningParametersResolver; +import org.opensaml.xmlsec.config.impl.DefaultSecurityConfigurationBootstrap; +import org.opensaml.xmlsec.criterion.SignatureSigningConfigurationCriterion; +import org.opensaml.xmlsec.crypto.XMLSigningUtil; +import org.opensaml.xmlsec.encryption.support.ChainingEncryptedKeyResolver; +import org.opensaml.xmlsec.encryption.support.DecryptionException; +import org.opensaml.xmlsec.encryption.support.EncryptedKeyResolver; +import org.opensaml.xmlsec.encryption.support.InlineEncryptedKeyResolver; +import org.opensaml.xmlsec.encryption.support.SimpleRetrievalMethodEncryptedKeyResolver; +import org.opensaml.xmlsec.impl.BasicSignatureSigningConfiguration; +import org.opensaml.xmlsec.keyinfo.KeyInfoCredentialResolver; +import org.opensaml.xmlsec.keyinfo.KeyInfoGeneratorManager; +import org.opensaml.xmlsec.keyinfo.NamedKeyInfoGeneratorManager; +import org.opensaml.xmlsec.keyinfo.impl.CollectionKeyInfoCredentialResolver; +import org.opensaml.xmlsec.keyinfo.impl.X509KeyInfoGeneratorFactory; +import org.opensaml.xmlsec.signature.SignableXMLObject; +import org.opensaml.xmlsec.signature.Signature; +import org.opensaml.xmlsec.signature.support.SignatureConstants; +import org.opensaml.xmlsec.signature.support.SignatureSupport; +import org.opensaml.xmlsec.signature.support.SignatureTrustEngine; +import org.opensaml.xmlsec.signature.support.impl.ExplicitKeySignatureTrustEngine; +import org.w3c.dom.Document; +import org.w3c.dom.Element; + +import org.springframework.security.saml2.Saml2Exception; +import org.springframework.security.saml2.core.Saml2Error; +import org.springframework.security.saml2.core.Saml2ErrorCodes; +import org.springframework.security.saml2.core.Saml2ParameterNames; +import org.springframework.security.saml2.core.Saml2X509Credential; +import org.springframework.util.Assert; +import org.springframework.web.util.UriComponentsBuilder; +import org.springframework.web.util.UriUtils; + +/** + * For internal use only. Subject to breaking changes at any time. + */ +final class OpenSaml4Template implements OpenSamlOperations { + + private static final Log logger = LogFactory.getLog(OpenSaml4Template.class); + + @Override + public T build(QName elementName) { + XMLObjectBuilder builder = XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(elementName); + if (builder == null) { + throw new Saml2Exception("Unable to resolve Builder for " + elementName); + } + return (T) builder.buildObject(elementName); + } + + @Override + public T deserialize(String serialized) { + return deserialize(new ByteArrayInputStream(serialized.getBytes(StandardCharsets.UTF_8))); + } + + @Override + public T deserialize(InputStream serialized) { + try { + Document document = XMLObjectProviderRegistrySupport.getParserPool().parse(serialized); + Element element = document.getDocumentElement(); + UnmarshallerFactory factory = XMLObjectProviderRegistrySupport.getUnmarshallerFactory(); + Unmarshaller unmarshaller = factory.getUnmarshaller(element); + if (unmarshaller == null) { + throw new Saml2Exception("Unsupported element of type " + element.getTagName()); + } + return (T) unmarshaller.unmarshall(element); + } + catch (Saml2Exception ex) { + throw ex; + } + catch (Exception ex) { + throw new Saml2Exception("Failed to deserialize payload", ex); + } + } + + @Override + public OpenSaml4SerializationConfigurer serialize(XMLObject object) { + Marshaller marshaller = XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(object); + try { + return serialize(marshaller.marshall(object)); + } + catch (MarshallingException ex) { + throw new Saml2Exception(ex); + } + } + + @Override + public OpenSaml4SerializationConfigurer serialize(Element element) { + return new OpenSaml4SerializationConfigurer(element); + } + + @Override + public OpenSaml4SignatureConfigurer withSigningKeys(Collection credentials) { + return new OpenSaml4SignatureConfigurer(credentials); + } + + @Override + public OpenSaml4VerificationConfigurer withVerificationKeys(Collection credentials) { + return new OpenSaml4VerificationConfigurer(credentials); + } + + @Override + public OpenSaml4DecryptionConfigurer withDecryptionKeys(Collection credentials) { + return new OpenSaml4DecryptionConfigurer(credentials); + } + + OpenSaml4Template() { + + } + + static final class OpenSaml4SerializationConfigurer + implements SerializationConfigurer { + + private final Element element; + + boolean pretty; + + OpenSaml4SerializationConfigurer(Element element) { + this.element = element; + } + + @Override + public OpenSaml4SerializationConfigurer prettyPrint(boolean pretty) { + this.pretty = pretty; + return this; + } + + @Override + public String serialize() { + if (this.pretty) { + return SerializeSupport.prettyPrintXML(this.element); + } + return SerializeSupport.nodeToString(this.element); + } + + } + + static final class OpenSaml4SignatureConfigurer implements SignatureConfigurer { + + private final Collection credentials; + + private final Map components = new LinkedHashMap<>(); + + private List algs = List.of(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256); + + OpenSaml4SignatureConfigurer(Collection credentials) { + this.credentials = credentials; + } + + @Override + public OpenSaml4SignatureConfigurer algorithms(List algs) { + this.algs = algs; + return this; + } + + @Override + public O sign(O object) { + SignatureSigningParameters parameters = resolveSigningParameters(); + try { + SignatureSupport.signObject(object, parameters); + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + return object; + } + + @Override + public Map sign(Map params) { + SignatureSigningParameters parameters = resolveSigningParameters(); + this.components.putAll(params); + Credential credential = parameters.getSigningCredential(); + String algorithmUri = parameters.getSignatureAlgorithm(); + this.components.put(Saml2ParameterNames.SIG_ALG, algorithmUri); + UriComponentsBuilder builder = UriComponentsBuilder.newInstance(); + for (Map.Entry component : this.components.entrySet()) { + builder.queryParam(component.getKey(), + UriUtils.encode(component.getValue(), StandardCharsets.ISO_8859_1)); + } + String queryString = builder.build(true).toString().substring(1); + try { + byte[] rawSignature = XMLSigningUtil.signWithURI(credential, algorithmUri, + queryString.getBytes(StandardCharsets.UTF_8)); + String b64Signature = Saml2Utils.samlEncode(rawSignature); + this.components.put(Saml2ParameterNames.SIGNATURE, b64Signature); + } + catch (SecurityException ex) { + throw new Saml2Exception(ex); + } + return this.components; + } + + private SignatureSigningParameters resolveSigningParameters() { + List credentials = resolveSigningCredentials(); + List digests = Collections.singletonList(SignatureConstants.ALGO_ID_DIGEST_SHA256); + String canonicalization = SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS; + SignatureSigningParametersResolver resolver = new SAMLMetadataSignatureSigningParametersResolver(); + BasicSignatureSigningConfiguration signingConfiguration = new BasicSignatureSigningConfiguration(); + signingConfiguration.setSigningCredentials(credentials); + signingConfiguration.setSignatureAlgorithms(this.algs); + signingConfiguration.setSignatureReferenceDigestMethods(digests); + signingConfiguration.setSignatureCanonicalizationAlgorithm(canonicalization); + signingConfiguration.setKeyInfoGeneratorManager(buildSignatureKeyInfoGeneratorManager()); + CriteriaSet criteria = new CriteriaSet(new SignatureSigningConfigurationCriterion(signingConfiguration)); + try { + SignatureSigningParameters parameters = resolver.resolveSingle(criteria); + Assert.notNull(parameters, "Failed to resolve any signing credential"); + return parameters; + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + } + + private NamedKeyInfoGeneratorManager buildSignatureKeyInfoGeneratorManager() { + final NamedKeyInfoGeneratorManager namedManager = new NamedKeyInfoGeneratorManager(); + + namedManager.setUseDefaultManager(true); + final KeyInfoGeneratorManager defaultManager = namedManager.getDefaultManager(); + + // Generator for X509Credentials + final X509KeyInfoGeneratorFactory x509Factory = new X509KeyInfoGeneratorFactory(); + x509Factory.setEmitEntityCertificate(true); + x509Factory.setEmitEntityCertificateChain(true); + + defaultManager.registerFactory(x509Factory); + + return namedManager; + } + + private List resolveSigningCredentials() { + List credentials = new ArrayList<>(); + for (Saml2X509Credential x509Credential : this.credentials) { + X509Certificate certificate = x509Credential.getCertificate(); + PrivateKey privateKey = x509Credential.getPrivateKey(); + BasicCredential credential = CredentialSupport.getSimpleCredential(certificate, privateKey); + credential.setUsageType(UsageType.SIGNING); + credentials.add(credential); + } + return credentials; + } + + } + + static final class OpenSaml4VerificationConfigurer implements VerificationConfigurer { + + private final Collection credentials; + + private String entityId; + + OpenSaml4VerificationConfigurer(Collection credentials) { + this.credentials = credentials; + } + + @Override + public VerificationConfigurer entityId(String entityId) { + this.entityId = entityId; + return this; + } + + private SignatureTrustEngine trustEngine(Collection keys) { + Set credentials = new HashSet<>(); + for (Saml2X509Credential key : keys) { + BasicX509Credential cred = new BasicX509Credential(key.getCertificate()); + cred.setUsageType(UsageType.SIGNING); + cred.setEntityId(this.entityId); + credentials.add(cred); + } + CredentialResolver credentialsResolver = new CollectionCredentialResolver(credentials); + return new ExplicitKeySignatureTrustEngine(credentialsResolver, + DefaultSecurityConfigurationBootstrap.buildBasicInlineKeyInfoCredentialResolver()); + } + + private CriteriaSet verificationCriteria(Issuer issuer) { + return new CriteriaSet(new EvaluableEntityIDCredentialCriterion(new EntityIdCriterion(issuer.getValue())), + new EvaluableProtocolRoleDescriptorCriterion(new ProtocolCriterion(SAMLConstants.SAML20P_NS)), + new EvaluableUsageCredentialCriterion(new UsageCriterion(UsageType.SIGNING))); + } + + @Override + public Collection verify(SignableXMLObject signable) { + if (signable instanceof StatusResponseType response) { + return verifySignature(response.getID(), response.getIssuer(), response.getSignature()); + } + if (signable instanceof RequestAbstractType request) { + return verifySignature(request.getID(), request.getIssuer(), request.getSignature()); + } + if (signable instanceof Assertion assertion) { + return verifySignature(assertion.getID(), assertion.getIssuer(), assertion.getSignature()); + } + throw new Saml2Exception("Unsupported object of type: " + signable.getClass().getName()); + } + + private Collection verifySignature(String id, Issuer issuer, Signature signature) { + SignatureTrustEngine trustEngine = trustEngine(this.credentials); + CriteriaSet criteria = verificationCriteria(issuer); + Collection errors = new ArrayList<>(); + SAMLSignatureProfileValidator profileValidator = new SAMLSignatureProfileValidator(); + try { + profileValidator.validate(signature); + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + id + "]: ")); + } + + try { + if (!trustEngine.validate(signature, criteria)) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + id + "]")); + } + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + id + "]: ")); + } + + return errors; + } + + @Override + public Collection verify(RedirectParameters parameters) { + SignatureTrustEngine trustEngine = trustEngine(this.credentials); + CriteriaSet criteria = verificationCriteria(parameters.getIssuer()); + if (parameters.getAlgorithm() == null) { + return Collections.singletonList(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Missing signature algorithm for object [" + parameters.getId() + "]")); + } + if (!parameters.hasSignature()) { + return Collections.singletonList(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Missing signature for object [" + parameters.getId() + "]")); + } + Collection errors = new ArrayList<>(); + String algorithmUri = parameters.getAlgorithm(); + try { + if (!trustEngine.validate(parameters.getSignature(), parameters.getContent(), algorithmUri, criteria, + null)) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + parameters.getId() + "]")); + } + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + parameters.getId() + "]: ")); + } + return errors; + } + + } + + static final class OpenSaml4DecryptionConfigurer implements DecryptionConfigurer { + + private static final EncryptedKeyResolver encryptedKeyResolver = new ChainingEncryptedKeyResolver( + Arrays.asList(new InlineEncryptedKeyResolver(), new EncryptedElementTypeEncryptedKeyResolver(), + new SimpleRetrievalMethodEncryptedKeyResolver())); + + private final Decrypter decrypter; + + OpenSaml4DecryptionConfigurer(Collection decryptionCredentials) { + this.decrypter = decrypter(decryptionCredentials); + } + + private static Decrypter decrypter(Collection decryptionCredentials) { + Collection credentials = new ArrayList<>(); + for (Saml2X509Credential key : decryptionCredentials) { + Credential cred = CredentialSupport.getSimpleCredential(key.getCertificate(), key.getPrivateKey()); + credentials.add(cred); + } + KeyInfoCredentialResolver resolver = new CollectionKeyInfoCredentialResolver(credentials); + Decrypter decrypter = new Decrypter(null, resolver, encryptedKeyResolver); + decrypter.setRootInNewDocument(true); + return decrypter; + } + + @Override + public void decrypt(XMLObject object) { + if (object instanceof Response response) { + decryptResponse(response); + return; + } + if (object instanceof Assertion assertion) { + decryptAssertion(assertion); + } + if (object instanceof LogoutRequest request) { + decryptLogoutRequest(request); + } + } + + /* + * The methods that follow are adapted from OpenSAML's {@link DecryptAssertions}, + * {@link DecryptNameIDs}, and {@link DecryptAttributes}. + * + *

The reason that these OpenSAML classes are not used directly is because they + * reference {@link javax.servlet.http.HttpServletRequest} which is a lower + * Servlet API version than what Spring Security SAML uses. + * + * If OpenSAML 5 updates to {@link jakarta.servlet.http.HttpServletRequest}, then + * this arrangement can be revisited. + */ + + private void decryptResponse(Response response) { + Collection decrypteds = new ArrayList<>(); + Collection encrypteds = new ArrayList<>(); + + int count = 0; + int size = response.getEncryptedAssertions().size(); + for (EncryptedAssertion encrypted : response.getEncryptedAssertions()) { + logger.trace(String.format("Decrypting EncryptedAssertion (%d/%d) in Response [%s]", count, size, + response.getID())); + try { + Assertion decrypted = this.decrypter.decrypt(encrypted); + if (decrypted != null) { + encrypteds.add(encrypted); + decrypteds.add(decrypted); + } + count++; + } + catch (DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + + response.getEncryptedAssertions().removeAll(encrypteds); + response.getAssertions().addAll(decrypteds); + + // Re-marshall the response so that any ID attributes within the decrypted + // Assertions + // will have their ID-ness re-established at the DOM level. + if (!decrypteds.isEmpty()) { + try { + XMLObjectSupport.marshall(response); + } + catch (final MarshallingException ex) { + throw new Saml2Exception(ex); + } + } + } + + private void decryptAssertion(Assertion assertion) { + for (AttributeStatement statement : assertion.getAttributeStatements()) { + decryptAttributes(statement); + } + decryptSubject(assertion.getSubject()); + if (assertion.getConditions() != null) { + for (Condition c : assertion.getConditions().getConditions()) { + if (!(c instanceof DelegationRestrictionType delegation)) { + continue; + } + for (Delegate d : delegation.getDelegates()) { + if (d.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(d.getEncryptedID()); + if (decrypted != null) { + d.setNameID(decrypted); + d.setEncryptedID(null); + } + } + catch (DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + } + } + } + } + + private void decryptAttributes(AttributeStatement statement) { + Collection decrypteds = new ArrayList<>(); + Collection encrypteds = new ArrayList<>(); + for (EncryptedAttribute encrypted : statement.getEncryptedAttributes()) { + try { + Attribute decrypted = this.decrypter.decrypt(encrypted); + if (decrypted != null) { + encrypteds.add(encrypted); + decrypteds.add(decrypted); + } + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + } + statement.getEncryptedAttributes().removeAll(encrypteds); + statement.getAttributes().addAll(decrypteds); + } + + private void decryptSubject(Subject subject) { + if (subject != null) { + if (subject.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(subject.getEncryptedID()); + if (decrypted != null) { + subject.setNameID(decrypted); + subject.setEncryptedID(null); + } + } + catch (final DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + + for (final SubjectConfirmation sc : subject.getSubjectConfirmations()) { + if (sc.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(sc.getEncryptedID()); + if (decrypted != null) { + sc.setNameID(decrypted); + sc.setEncryptedID(null); + } + } + catch (final DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + } + } + } + + private void decryptLogoutRequest(LogoutRequest request) { + if (request.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(request.getEncryptedID()); + if (decrypted != null) { + request.setNameID(decrypted); + request.setEncryptedID(null); + } + } + catch (DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + } + + } + +} diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlDecryptionUtils.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlDecryptionUtils.java deleted file mode 100644 index e80bd0619b..0000000000 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlDecryptionUtils.java +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Copyright 2002-2021 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.saml2.provider.service.authentication; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; - -import org.opensaml.saml.saml2.core.Assertion; -import org.opensaml.saml.saml2.core.Attribute; -import org.opensaml.saml.saml2.core.AttributeStatement; -import org.opensaml.saml.saml2.core.EncryptedAssertion; -import org.opensaml.saml.saml2.core.EncryptedAttribute; -import org.opensaml.saml.saml2.core.NameID; -import org.opensaml.saml.saml2.core.Response; -import org.opensaml.saml.saml2.encryption.Decrypter; -import org.opensaml.saml.saml2.encryption.EncryptedElementTypeEncryptedKeyResolver; -import org.opensaml.security.credential.Credential; -import org.opensaml.security.credential.CredentialSupport; -import org.opensaml.xmlsec.encryption.support.ChainingEncryptedKeyResolver; -import org.opensaml.xmlsec.encryption.support.EncryptedKeyResolver; -import org.opensaml.xmlsec.encryption.support.InlineEncryptedKeyResolver; -import org.opensaml.xmlsec.encryption.support.SimpleRetrievalMethodEncryptedKeyResolver; -import org.opensaml.xmlsec.keyinfo.KeyInfoCredentialResolver; -import org.opensaml.xmlsec.keyinfo.impl.CollectionKeyInfoCredentialResolver; - -import org.springframework.security.saml2.Saml2Exception; -import org.springframework.security.saml2.core.Saml2X509Credential; -import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; - -/** - * Utility methods for decrypting SAML components with OpenSAML - * - * For internal use only. - * - * @author Josh Cummings - */ -final class OpenSamlDecryptionUtils { - - private static final EncryptedKeyResolver encryptedKeyResolver = new ChainingEncryptedKeyResolver( - Arrays.asList(new InlineEncryptedKeyResolver(), new EncryptedElementTypeEncryptedKeyResolver(), - new SimpleRetrievalMethodEncryptedKeyResolver())); - - static void decryptResponseElements(Response response, RelyingPartyRegistration registration) { - Decrypter decrypter = decrypter(registration); - for (EncryptedAssertion encryptedAssertion : response.getEncryptedAssertions()) { - try { - Assertion assertion = decrypter.decrypt(encryptedAssertion); - response.getAssertions().add(assertion); - } - catch (Exception ex) { - throw new Saml2Exception(ex); - } - } - } - - static void decryptAssertionElements(Assertion assertion, RelyingPartyRegistration registration) { - Decrypter decrypter = decrypter(registration); - for (AttributeStatement statement : assertion.getAttributeStatements()) { - for (EncryptedAttribute encryptedAttribute : statement.getEncryptedAttributes()) { - try { - Attribute attribute = decrypter.decrypt(encryptedAttribute); - statement.getAttributes().add(attribute); - } - catch (Exception ex) { - throw new Saml2Exception(ex); - } - } - } - if (assertion.getSubject() == null) { - return; - } - if (assertion.getSubject().getEncryptedID() == null) { - return; - } - try { - assertion.getSubject().setNameID((NameID) decrypter.decrypt(assertion.getSubject().getEncryptedID())); - } - catch (Exception ex) { - throw new Saml2Exception(ex); - } - } - - private static Decrypter decrypter(RelyingPartyRegistration registration) { - Collection credentials = new ArrayList<>(); - for (Saml2X509Credential key : registration.getDecryptionX509Credentials()) { - Credential cred = CredentialSupport.getSimpleCredential(key.getCertificate(), key.getPrivateKey()); - credentials.add(cred); - } - KeyInfoCredentialResolver resolver = new CollectionKeyInfoCredentialResolver(credentials); - Decrypter decrypter = new Decrypter(null, resolver, encryptedKeyResolver); - decrypter.setRootInNewDocument(true); - return decrypter; - } - - private OpenSamlDecryptionUtils() { - } - -} diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlOperations.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlOperations.java new file mode 100644 index 0000000000..f2c055839d --- /dev/null +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlOperations.java @@ -0,0 +1,184 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.authentication; + +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import javax.xml.namespace.QName; + +import org.opensaml.core.xml.XMLObject; +import org.opensaml.saml.saml2.core.Issuer; +import org.opensaml.saml.saml2.core.RequestAbstractType; +import org.opensaml.saml.saml2.core.StatusResponseType; +import org.opensaml.xmlsec.signature.SignableXMLObject; +import org.w3c.dom.Element; + +import org.springframework.security.saml2.core.Saml2Error; +import org.springframework.security.saml2.core.Saml2ParameterNames; +import org.springframework.security.saml2.core.Saml2X509Credential; +import org.springframework.web.util.UriComponentsBuilder; + +interface OpenSamlOperations { + + T build(QName elementName); + + T deserialize(String serialized); + + T deserialize(InputStream serialized); + + SerializationConfigurer serialize(XMLObject object); + + SerializationConfigurer serialize(Element element); + + SignatureConfigurer withSigningKeys(Collection credentials); + + VerificationConfigurer withVerificationKeys(Collection credentials); + + DecryptionConfigurer withDecryptionKeys(Collection credentials); + + interface SerializationConfigurer> { + + B prettyPrint(boolean pretty); + + String serialize(); + + } + + interface SignatureConfigurer> { + + B algorithms(List algs); + + O sign(O object); + + Map sign(Map params); + + } + + interface VerificationConfigurer { + + VerificationConfigurer entityId(String entityId); + + Collection verify(SignableXMLObject signable); + + Collection verify(VerificationConfigurer.RedirectParameters parameters); + + final class RedirectParameters { + + private final String id; + + private final Issuer issuer; + + private final String algorithm; + + private final byte[] signature; + + private final byte[] content; + + RedirectParameters(Map parameters, String parametersQuery, RequestAbstractType request) { + this.id = request.getID(); + this.issuer = request.getIssuer(); + this.algorithm = parameters.get(Saml2ParameterNames.SIG_ALG); + if (parameters.get(Saml2ParameterNames.SIGNATURE) != null) { + this.signature = Saml2Utils.samlDecode(parameters.get(Saml2ParameterNames.SIGNATURE)); + } + else { + this.signature = null; + } + Map queryParams = UriComponentsBuilder.newInstance() + .query(parametersQuery) + .build(true) + .getQueryParams() + .toSingleValueMap(); + String relayState = parameters.get(Saml2ParameterNames.RELAY_STATE); + this.content = getContent(Saml2ParameterNames.SAML_REQUEST, relayState, queryParams); + } + + RedirectParameters(Map parameters, String parametersQuery, StatusResponseType response) { + this.id = response.getID(); + this.issuer = response.getIssuer(); + this.algorithm = parameters.get(Saml2ParameterNames.SIG_ALG); + if (parameters.get(Saml2ParameterNames.SIGNATURE) != null) { + this.signature = Saml2Utils.samlDecode(parameters.get(Saml2ParameterNames.SIGNATURE)); + } + else { + this.signature = null; + } + Map queryParams = UriComponentsBuilder.newInstance() + .query(parametersQuery) + .build(true) + .getQueryParams() + .toSingleValueMap(); + String relayState = parameters.get(Saml2ParameterNames.RELAY_STATE); + this.content = getContent(Saml2ParameterNames.SAML_RESPONSE, relayState, queryParams); + } + + static byte[] getContent(String samlObject, String relayState, final Map queryParams) { + if (Objects.nonNull(relayState)) { + return String + .format("%s=%s&%s=%s&%s=%s", samlObject, queryParams.get(samlObject), + Saml2ParameterNames.RELAY_STATE, queryParams.get(Saml2ParameterNames.RELAY_STATE), + Saml2ParameterNames.SIG_ALG, queryParams.get(Saml2ParameterNames.SIG_ALG)) + .getBytes(StandardCharsets.UTF_8); + } + else { + return String + .format("%s=%s&%s=%s", samlObject, queryParams.get(samlObject), Saml2ParameterNames.SIG_ALG, + queryParams.get(Saml2ParameterNames.SIG_ALG)) + .getBytes(StandardCharsets.UTF_8); + } + } + + String getId() { + return this.id; + } + + Issuer getIssuer() { + return this.issuer; + } + + byte[] getContent() { + return this.content; + } + + String getAlgorithm() { + return this.algorithm; + } + + byte[] getSignature() { + return this.signature; + } + + boolean hasSignature() { + return this.signature != null; + } + + } + + } + + interface DecryptionConfigurer { + + void decrypt(XMLObject object); + + } + +} diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlSigningUtils.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlSigningUtils.java deleted file mode 100644 index 4861ebce7e..0000000000 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlSigningUtils.java +++ /dev/null @@ -1,194 +0,0 @@ -/* - * Copyright 2002-2021 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.saml2.provider.service.authentication; - -import java.nio.charset.StandardCharsets; -import java.security.PrivateKey; -import java.security.cert.X509Certificate; -import java.util.ArrayList; -import java.util.Collections; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; - -import net.shibboleth.utilities.java.support.resolver.CriteriaSet; -import net.shibboleth.utilities.java.support.xml.SerializeSupport; -import org.opensaml.core.xml.XMLObject; -import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; -import org.opensaml.core.xml.io.Marshaller; -import org.opensaml.core.xml.io.MarshallingException; -import org.opensaml.saml.security.impl.SAMLMetadataSignatureSigningParametersResolver; -import org.opensaml.security.SecurityException; -import org.opensaml.security.credential.BasicCredential; -import org.opensaml.security.credential.Credential; -import org.opensaml.security.credential.CredentialSupport; -import org.opensaml.security.credential.UsageType; -import org.opensaml.xmlsec.SignatureSigningParameters; -import org.opensaml.xmlsec.SignatureSigningParametersResolver; -import org.opensaml.xmlsec.criterion.SignatureSigningConfigurationCriterion; -import org.opensaml.xmlsec.crypto.XMLSigningUtil; -import org.opensaml.xmlsec.impl.BasicSignatureSigningConfiguration; -import org.opensaml.xmlsec.keyinfo.KeyInfoGeneratorManager; -import org.opensaml.xmlsec.keyinfo.NamedKeyInfoGeneratorManager; -import org.opensaml.xmlsec.keyinfo.impl.X509KeyInfoGeneratorFactory; -import org.opensaml.xmlsec.signature.SignableXMLObject; -import org.opensaml.xmlsec.signature.support.SignatureConstants; -import org.opensaml.xmlsec.signature.support.SignatureSupport; -import org.w3c.dom.Element; - -import org.springframework.security.saml2.Saml2Exception; -import org.springframework.security.saml2.core.Saml2ParameterNames; -import org.springframework.security.saml2.core.Saml2X509Credential; -import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; -import org.springframework.util.Assert; -import org.springframework.web.util.UriComponentsBuilder; -import org.springframework.web.util.UriUtils; - -/** - * Utility methods for signing SAML components with OpenSAML - * - * For internal use only. - * - * @author Josh Cummings - */ -final class OpenSamlSigningUtils { - - static String serialize(XMLObject object) { - try { - Marshaller marshaller = XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(object); - Element element = marshaller.marshall(object); - return SerializeSupport.nodeToString(element); - } - catch (MarshallingException ex) { - throw new Saml2Exception(ex); - } - } - - static O sign(O object, RelyingPartyRegistration relyingPartyRegistration) { - SignatureSigningParameters parameters = resolveSigningParameters(relyingPartyRegistration); - try { - SignatureSupport.signObject(object, parameters); - return object; - } - catch (Exception ex) { - throw new Saml2Exception(ex); - } - } - - static QueryParametersPartial sign(RelyingPartyRegistration registration) { - return new QueryParametersPartial(registration); - } - - private static SignatureSigningParameters resolveSigningParameters( - RelyingPartyRegistration relyingPartyRegistration) { - List credentials = resolveSigningCredentials(relyingPartyRegistration); - List algorithms = relyingPartyRegistration.getAssertingPartyMetadata().getSigningAlgorithms(); - List digests = Collections.singletonList(SignatureConstants.ALGO_ID_DIGEST_SHA256); - String canonicalization = SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS; - SignatureSigningParametersResolver resolver = new SAMLMetadataSignatureSigningParametersResolver(); - CriteriaSet criteria = new CriteriaSet(); - BasicSignatureSigningConfiguration signingConfiguration = new BasicSignatureSigningConfiguration(); - signingConfiguration.setSigningCredentials(credentials); - signingConfiguration.setSignatureAlgorithms(algorithms); - signingConfiguration.setSignatureReferenceDigestMethods(digests); - signingConfiguration.setSignatureCanonicalizationAlgorithm(canonicalization); - signingConfiguration.setKeyInfoGeneratorManager(buildSignatureKeyInfoGeneratorManager()); - criteria.add(new SignatureSigningConfigurationCriterion(signingConfiguration)); - try { - SignatureSigningParameters parameters = resolver.resolveSingle(criteria); - Assert.notNull(parameters, "Failed to resolve any signing credential"); - return parameters; - } - catch (Exception ex) { - throw new Saml2Exception(ex); - } - } - - private static NamedKeyInfoGeneratorManager buildSignatureKeyInfoGeneratorManager() { - final NamedKeyInfoGeneratorManager namedManager = new NamedKeyInfoGeneratorManager(); - - namedManager.setUseDefaultManager(true); - final KeyInfoGeneratorManager defaultManager = namedManager.getDefaultManager(); - - // Generator for X509Credentials - final X509KeyInfoGeneratorFactory x509Factory = new X509KeyInfoGeneratorFactory(); - x509Factory.setEmitEntityCertificate(true); - x509Factory.setEmitEntityCertificateChain(true); - - defaultManager.registerFactory(x509Factory); - - return namedManager; - } - - private static List resolveSigningCredentials(RelyingPartyRegistration relyingPartyRegistration) { - List credentials = new ArrayList<>(); - for (Saml2X509Credential x509Credential : relyingPartyRegistration.getSigningX509Credentials()) { - X509Certificate certificate = x509Credential.getCertificate(); - PrivateKey privateKey = x509Credential.getPrivateKey(); - BasicCredential credential = CredentialSupport.getSimpleCredential(certificate, privateKey); - credential.setEntityId(relyingPartyRegistration.getEntityId()); - credential.setUsageType(UsageType.SIGNING); - credentials.add(credential); - } - return credentials; - } - - private OpenSamlSigningUtils() { - - } - - static class QueryParametersPartial { - - final RelyingPartyRegistration registration; - - final Map components = new LinkedHashMap<>(); - - QueryParametersPartial(RelyingPartyRegistration registration) { - this.registration = registration; - } - - QueryParametersPartial param(String key, String value) { - this.components.put(key, value); - return this; - } - - Map parameters() { - SignatureSigningParameters parameters = resolveSigningParameters(this.registration); - Credential credential = parameters.getSigningCredential(); - String algorithmUri = parameters.getSignatureAlgorithm(); - this.components.put(Saml2ParameterNames.SIG_ALG, algorithmUri); - UriComponentsBuilder builder = UriComponentsBuilder.newInstance(); - for (Map.Entry component : this.components.entrySet()) { - builder.queryParam(component.getKey(), - UriUtils.encode(component.getValue(), StandardCharsets.ISO_8859_1)); - } - String queryString = builder.build(true).toString().substring(1); - try { - byte[] rawSignature = XMLSigningUtil.signWithURI(credential, algorithmUri, - queryString.getBytes(StandardCharsets.UTF_8)); - String b64Signature = Saml2Utils.samlEncode(rawSignature); - this.components.put(Saml2ParameterNames.SIGNATURE, b64Signature); - } - catch (SecurityException ex) { - throw new Saml2Exception(ex); - } - return this.components; - } - - } - -} diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlVerificationUtils.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlVerificationUtils.java deleted file mode 100644 index 93928202a2..0000000000 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlVerificationUtils.java +++ /dev/null @@ -1,222 +0,0 @@ -/* - * Copyright 2002-2021 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.saml2.provider.service.authentication; - -import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashSet; -import java.util.Set; - -import jakarta.servlet.http.HttpServletRequest; -import net.shibboleth.utilities.java.support.resolver.CriteriaSet; -import org.opensaml.core.criterion.EntityIdCriterion; -import org.opensaml.saml.common.xml.SAMLConstants; -import org.opensaml.saml.criterion.ProtocolCriterion; -import org.opensaml.saml.metadata.criteria.role.impl.EvaluableProtocolRoleDescriptorCriterion; -import org.opensaml.saml.saml2.core.Issuer; -import org.opensaml.saml.saml2.core.RequestAbstractType; -import org.opensaml.saml.saml2.core.StatusResponseType; -import org.opensaml.saml.security.impl.SAMLSignatureProfileValidator; -import org.opensaml.security.credential.Credential; -import org.opensaml.security.credential.CredentialResolver; -import org.opensaml.security.credential.UsageType; -import org.opensaml.security.credential.criteria.impl.EvaluableEntityIDCredentialCriterion; -import org.opensaml.security.credential.criteria.impl.EvaluableUsageCredentialCriterion; -import org.opensaml.security.credential.impl.CollectionCredentialResolver; -import org.opensaml.security.criteria.UsageCriterion; -import org.opensaml.security.x509.BasicX509Credential; -import org.opensaml.xmlsec.config.impl.DefaultSecurityConfigurationBootstrap; -import org.opensaml.xmlsec.signature.Signature; -import org.opensaml.xmlsec.signature.support.SignatureTrustEngine; -import org.opensaml.xmlsec.signature.support.impl.ExplicitKeySignatureTrustEngine; - -import org.springframework.security.saml2.core.Saml2Error; -import org.springframework.security.saml2.core.Saml2ErrorCodes; -import org.springframework.security.saml2.core.Saml2ParameterNames; -import org.springframework.security.saml2.core.Saml2ResponseValidatorResult; -import org.springframework.security.saml2.core.Saml2X509Credential; -import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; -import org.springframework.web.util.UriUtils; - -/** - * Utility methods for verifying SAML component signatures with OpenSAML - * - * For internal use only. - * - * @author Josh Cummings - */ - -final class OpenSamlVerificationUtils { - - static VerifierPartial verifySignature(StatusResponseType object, RelyingPartyRegistration registration) { - return new VerifierPartial(object, registration); - } - - static VerifierPartial verifySignature(RequestAbstractType object, RelyingPartyRegistration registration) { - return new VerifierPartial(object, registration); - } - - static SignatureTrustEngine trustEngine(RelyingPartyRegistration registration) { - Set credentials = new HashSet<>(); - Collection keys = registration.getAssertingPartyMetadata() - .getVerificationX509Credentials(); - for (Saml2X509Credential key : keys) { - BasicX509Credential cred = new BasicX509Credential(key.getCertificate()); - cred.setUsageType(UsageType.SIGNING); - cred.setEntityId(registration.getAssertingPartyMetadata().getEntityId()); - credentials.add(cred); - } - CredentialResolver credentialsResolver = new CollectionCredentialResolver(credentials); - return new ExplicitKeySignatureTrustEngine(credentialsResolver, - DefaultSecurityConfigurationBootstrap.buildBasicInlineKeyInfoCredentialResolver()); - } - - private OpenSamlVerificationUtils() { - - } - - static class VerifierPartial { - - private final String id; - - private final CriteriaSet criteria; - - private final SignatureTrustEngine trustEngine; - - VerifierPartial(StatusResponseType object, RelyingPartyRegistration registration) { - this.id = object.getID(); - this.criteria = verificationCriteria(object.getIssuer()); - this.trustEngine = trustEngine(registration); - } - - VerifierPartial(RequestAbstractType object, RelyingPartyRegistration registration) { - this.id = object.getID(); - this.criteria = verificationCriteria(object.getIssuer()); - this.trustEngine = trustEngine(registration); - } - - Saml2ResponseValidatorResult redirect(HttpServletRequest request, String objectParameterName) { - RedirectSignature signature = new RedirectSignature(request, objectParameterName); - if (signature.getAlgorithm() == null) { - return Saml2ResponseValidatorResult.failure(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, - "Missing signature algorithm for object [" + this.id + "]")); - } - if (!signature.hasSignature()) { - return Saml2ResponseValidatorResult.failure(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, - "Missing signature for object [" + this.id + "]")); - } - Collection errors = new ArrayList<>(); - String algorithmUri = signature.getAlgorithm(); - try { - if (!this.trustEngine.validate(signature.getSignature(), signature.getContent(), algorithmUri, - this.criteria, null)) { - errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, - "Invalid signature for object [" + this.id + "]")); - } - } - catch (Exception ex) { - errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, - "Invalid signature for object [" + this.id + "]: ")); - } - return Saml2ResponseValidatorResult.failure(errors); - } - - Saml2ResponseValidatorResult post(Signature signature) { - Collection errors = new ArrayList<>(); - SAMLSignatureProfileValidator profileValidator = new SAMLSignatureProfileValidator(); - try { - profileValidator.validate(signature); - } - catch (Exception ex) { - errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, - "Invalid signature for object [" + this.id + "]: ")); - } - - try { - if (!this.trustEngine.validate(signature, this.criteria)) { - errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, - "Invalid signature for object [" + this.id + "]")); - } - } - catch (Exception ex) { - errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, - "Invalid signature for object [" + this.id + "]: ")); - } - - return Saml2ResponseValidatorResult.failure(errors); - } - - private CriteriaSet verificationCriteria(Issuer issuer) { - CriteriaSet criteria = new CriteriaSet(); - criteria.add(new EvaluableEntityIDCredentialCriterion(new EntityIdCriterion(issuer.getValue()))); - criteria.add(new EvaluableProtocolRoleDescriptorCriterion(new ProtocolCriterion(SAMLConstants.SAML20P_NS))); - criteria.add(new EvaluableUsageCredentialCriterion(new UsageCriterion(UsageType.SIGNING))); - return criteria; - } - - private static class RedirectSignature { - - private final HttpServletRequest request; - - private final String objectParameterName; - - RedirectSignature(HttpServletRequest request, String objectParameterName) { - this.request = request; - this.objectParameterName = objectParameterName; - } - - String getAlgorithm() { - return this.request.getParameter(Saml2ParameterNames.SIG_ALG); - } - - byte[] getContent() { - if (this.request.getParameter(Saml2ParameterNames.RELAY_STATE) != null) { - return String - .format("%s=%s&%s=%s&%s=%s", this.objectParameterName, UriUtils - .encode(this.request.getParameter(this.objectParameterName), StandardCharsets.ISO_8859_1), - Saml2ParameterNames.RELAY_STATE, - UriUtils.encode(this.request.getParameter(Saml2ParameterNames.RELAY_STATE), - StandardCharsets.ISO_8859_1), - Saml2ParameterNames.SIG_ALG, - UriUtils.encode(getAlgorithm(), StandardCharsets.ISO_8859_1)) - .getBytes(StandardCharsets.UTF_8); - } - else { - return String - .format("%s=%s&%s=%s", this.objectParameterName, - UriUtils.encode(this.request.getParameter(this.objectParameterName), - StandardCharsets.ISO_8859_1), - Saml2ParameterNames.SIG_ALG, - UriUtils.encode(getAlgorithm(), StandardCharsets.ISO_8859_1)) - .getBytes(StandardCharsets.UTF_8); - } - } - - byte[] getSignature() { - return Saml2Utils.samlDecode(this.request.getParameter(Saml2ParameterNames.SIGNATURE)); - } - - boolean hasSignature() { - return this.request.getParameter(Saml2ParameterNames.SIGNATURE) != null; - } - - } - - } - -} diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2Utils.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2Utils.java index 1d1012f702..acbabd1e0d 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2Utils.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2Utils.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ package org.springframework.security.saml2.provider.service.authentication; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.util.Arrays; import java.util.Base64; import java.util.zip.Deflater; import java.util.zip.DeflaterOutputStream; @@ -28,7 +29,11 @@ import java.util.zip.InflaterOutputStream; import org.springframework.security.saml2.Saml2Exception; /** - * @since 5.3 + * Utility methods for working with serialized SAML messages. + * + * For internal use only. + * + * @author Josh Cummings */ final class Saml2Utils { @@ -69,4 +74,123 @@ final class Saml2Utils { } } + static EncodingConfigurer withDecoded(String decoded) { + return new EncodingConfigurer(decoded); + } + + static DecodingConfigurer withEncoded(String encoded) { + return new DecodingConfigurer(encoded); + } + + static final class EncodingConfigurer { + + private final String decoded; + + private boolean deflate; + + private EncodingConfigurer(String decoded) { + this.decoded = decoded; + } + + EncodingConfigurer deflate(boolean deflate) { + this.deflate = deflate; + return this; + } + + String encode() { + byte[] bytes = (this.deflate) ? Saml2Utils.samlDeflate(this.decoded) + : this.decoded.getBytes(StandardCharsets.UTF_8); + return Saml2Utils.samlEncode(bytes); + } + + } + + static final class DecodingConfigurer { + + private static final Base64Checker BASE_64_CHECKER = new Base64Checker(); + + private final String encoded; + + private boolean inflate; + + private boolean requireBase64; + + private DecodingConfigurer(String encoded) { + this.encoded = encoded; + } + + DecodingConfigurer inflate(boolean inflate) { + this.inflate = inflate; + return this; + } + + DecodingConfigurer requireBase64(boolean requireBase64) { + this.requireBase64 = requireBase64; + return this; + } + + String decode() { + if (this.requireBase64) { + BASE_64_CHECKER.checkAcceptable(this.encoded); + } + byte[] bytes = Saml2Utils.samlDecode(this.encoded); + return (this.inflate) ? Saml2Utils.samlInflate(bytes) : new String(bytes, StandardCharsets.UTF_8); + } + + static class Base64Checker { + + private static final int[] values = genValueMapping(); + + Base64Checker() { + + } + + private static int[] genValueMapping() { + byte[] alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/" + .getBytes(StandardCharsets.ISO_8859_1); + + int[] values = new int[256]; + Arrays.fill(values, -1); + for (int i = 0; i < alphabet.length; i++) { + values[alphabet[i] & 0xff] = i; + } + return values; + } + + boolean isAcceptable(String s) { + int goodChars = 0; + int lastGoodCharVal = -1; + + // count number of characters from Base64 alphabet + for (int i = 0; i < s.length(); i++) { + int val = values[0xff & s.charAt(i)]; + if (val != -1) { + lastGoodCharVal = val; + goodChars++; + } + } + + // in cases of an incomplete final chunk, ensure the unused bits are zero + switch (goodChars % 4) { + case 0: + return true; + case 2: + return (lastGoodCharVal & 0b1111) == 0; + case 3: + return (lastGoodCharVal & 0b11) == 0; + default: + return false; + } + } + + void checkAcceptable(String ins) { + if (!isAcceptable(ins)) { + throw new IllegalArgumentException("Failed to decode SAMLResponse"); + } + } + + } + + } + } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java index 0193a53147..1713de3bae 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java @@ -32,12 +32,9 @@ import java.util.function.Consumer; import javax.xml.namespace.QName; import com.fasterxml.jackson.databind.ObjectMapper; -import net.shibboleth.utilities.java.support.xml.SerializeSupport; import org.junit.jupiter.api.Test; import org.opensaml.core.xml.XMLObject; import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; -import org.opensaml.core.xml.io.Marshaller; -import org.opensaml.core.xml.io.MarshallingException; import org.opensaml.core.xml.schema.XSDateTime; import org.opensaml.core.xml.schema.impl.XSDateTimeBuilder; import org.opensaml.saml.common.SignableSAMLObject; @@ -68,12 +65,10 @@ import org.opensaml.saml.saml2.core.impl.StatusBuilder; import org.opensaml.saml.saml2.core.impl.StatusCodeBuilder; import org.opensaml.xmlsec.encryption.impl.EncryptedDataBuilder; import org.opensaml.xmlsec.signature.support.SignatureConstants; -import org.w3c.dom.Element; import org.springframework.core.convert.converter.Converter; import org.springframework.security.core.Authentication; import org.springframework.security.jackson2.SecurityJackson2Modules; -import org.springframework.security.saml2.Saml2Exception; import org.springframework.security.saml2.core.Saml2Error; import org.springframework.security.saml2.core.Saml2ErrorCodes; import org.springframework.security.saml2.core.Saml2ResponseValidatorResult; @@ -107,6 +102,8 @@ public class OpenSaml4AuthenticationProviderTests { private static String ASSERTING_PARTY_ENTITY_ID = "https://some.idp.test/saml2/idp"; + private final OpenSamlOperations saml = new OpenSaml4Template(); + private OpenSaml4AuthenticationProvider provider = new OpenSaml4AuthenticationProvider(); private Saml2AuthenticatedPrincipal principal = new DefaultSaml2AuthenticatedPrincipal("name", @@ -839,14 +836,7 @@ public class OpenSaml4AuthenticationProviderTests { } private String serialize(XMLObject object) { - try { - Marshaller marshaller = XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(object); - Element element = marshaller.marshall(object); - return SerializeSupport.nodeToString(element); - } - catch (MarshallingException ex) { - throw new Saml2Exception(ex); - } + return this.saml.serialize(object).serialize(); } private Consumer errorOf(String errorCode) { diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlSigningUtilsTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlSigningUtilsTests.java deleted file mode 100644 index 080bbb0919..0000000000 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlSigningUtilsTests.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright 2002-2021 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.saml2.provider.service.authentication; - -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.opensaml.saml.saml2.core.Response; -import org.opensaml.xmlsec.signature.Signature; - -import org.springframework.security.saml2.core.TestSaml2X509Credentials; -import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; - -import static org.assertj.core.api.Assertions.assertThat; - -/** - * Test open SAML signatures - */ -public class OpenSamlSigningUtilsTests { - - private RelyingPartyRegistration registration; - - @BeforeEach - public void setup() { - this.registration = RelyingPartyRegistration.withRegistrationId("saml-idp") - .entityId("https://some.idp.example.com/entity-id") - .signingX509Credentials((c) -> { - c.add(TestSaml2X509Credentials.relyingPartySigningCredential()); - c.add(TestSaml2X509Credentials.assertingPartySigningCredential()); - }) - .assertingPartyDetails((c) -> c.entityId("https://some.idp.example.com/entity-id") - .singleSignOnServiceLocation("https://some.idp.example.com/service-location")) - .build(); - } - - @Test - public void whenSigningAnObjectThenKeyInfoIsPartOfTheSignature() throws Exception { - Response response = TestOpenSamlObjects.response(); - OpenSamlSigningUtils.sign(response, this.registration); - Signature signature = response.getSignature(); - assertThat(signature).isNotNull(); - assertThat(signature.getKeyInfo()).isNotNull(); - } - -}