Use OpenSAML API in authentication

Issue gh-11658
This commit is contained in:
Josh Cummings 2024-08-02 18:50:28 -06:00
parent 416859e70e
commit 80b31820cd
No known key found for this signature in database
GPG Key ID: A306A51F43B8E5A5
11 changed files with 1652 additions and 1167 deletions

View File

@ -13,6 +13,12 @@ sourceSets.configureEach { set ->
filter { line -> line.replaceAll(".saml2.internal", ".saml2.provider.service.authentication.logout") }
with from
}
copy {
into "$projectDir/src/$set.name/java/org/springframework/security/saml2/provider/service/authentication"
filter { line -> line.replaceAll(".saml2.internal", ".saml2.provider.service.authentication") }
with from
}
}
dependencies {

View File

@ -0,0 +1,678 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.saml2.provider.service.authentication;
import java.lang.reflect.Field;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Consumer;
import javax.annotation.Nonnull;
import javax.xml.namespace.QName;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.core.xml.schema.XSAny;
import org.opensaml.core.xml.schema.XSBoolean;
import org.opensaml.core.xml.schema.XSBooleanValue;
import org.opensaml.core.xml.schema.XSDateTime;
import org.opensaml.core.xml.schema.XSInteger;
import org.opensaml.core.xml.schema.XSString;
import org.opensaml.core.xml.schema.XSURI;
import org.opensaml.saml.common.assertion.AssertionValidationException;
import org.opensaml.saml.common.assertion.ValidationContext;
import org.opensaml.saml.common.assertion.ValidationResult;
import org.opensaml.saml.saml2.assertion.ConditionValidator;
import org.opensaml.saml.saml2.assertion.SAML20AssertionValidator;
import org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters;
import org.opensaml.saml.saml2.assertion.StatementValidator;
import org.opensaml.saml.saml2.assertion.SubjectConfirmationValidator;
import org.opensaml.saml.saml2.assertion.impl.AudienceRestrictionConditionValidator;
import org.opensaml.saml.saml2.assertion.impl.BearerSubjectConfirmationValidator;
import org.opensaml.saml.saml2.assertion.impl.DelegationRestrictionConditionValidator;
import org.opensaml.saml.saml2.assertion.impl.ProxyRestrictionConditionValidator;
import org.opensaml.saml.saml2.core.Assertion;
import org.opensaml.saml.saml2.core.Attribute;
import org.opensaml.saml.saml2.core.AttributeStatement;
import org.opensaml.saml.saml2.core.AuthnStatement;
import org.opensaml.saml.saml2.core.Condition;
import org.opensaml.saml.saml2.core.OneTimeUse;
import org.opensaml.saml.saml2.core.Response;
import org.opensaml.saml.saml2.core.StatusCode;
import org.opensaml.saml.saml2.core.SubjectConfirmation;
import org.opensaml.saml.saml2.core.SubjectConfirmationData;
import org.springframework.core.convert.converter.Converter;
import org.springframework.core.log.LogMessage;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.core.OpenSamlInitializationService;
import org.springframework.security.saml2.core.Saml2Error;
import org.springframework.security.saml2.core.Saml2ErrorCodes;
import org.springframework.security.saml2.core.Saml2ResponseValidatorResult;
import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.security.saml2.provider.service.registration.AssertingPartyMetadata;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.StringUtils;
class BaseOpenSamlAuthenticationProvider implements AuthenticationProvider {
static {
OpenSamlInitializationService.initialize();
}
private final Log logger = LogFactory.getLog(this.getClass());
private final OpenSamlOperations saml;
private final Converter<ResponseToken, Saml2ResponseValidatorResult> responseSignatureValidator = createDefaultResponseSignatureValidator();
private Consumer<ResponseToken> responseElementsDecrypter = createDefaultResponseElementsDecrypter();
private Converter<ResponseToken, Saml2ResponseValidatorResult> responseValidator = createDefaultResponseValidator();
private final Converter<AssertionToken, Saml2ResponseValidatorResult> assertionSignatureValidator = createDefaultAssertionSignatureValidator();
private Consumer<AssertionToken> assertionElementsDecrypter = createDefaultAssertionElementsDecrypter();
private Converter<AssertionToken, Saml2ResponseValidatorResult> assertionValidator = createDefaultAssertionValidator();
private Converter<ResponseToken, ? extends AbstractAuthenticationToken> responseAuthenticationConverter = createDefaultResponseAuthenticationConverter();
private static final Set<String> includeChildStatusCodes = new HashSet<>(
Arrays.asList(StatusCode.REQUESTER, StatusCode.RESPONDER, StatusCode.VERSION_MISMATCH));
BaseOpenSamlAuthenticationProvider(OpenSamlOperations saml) {
this.saml = saml;
}
void setResponseElementsDecrypter(Consumer<ResponseToken> responseElementsDecrypter) {
Assert.notNull(responseElementsDecrypter, "responseElementsDecrypter cannot be null");
this.responseElementsDecrypter = responseElementsDecrypter;
}
void setResponseValidator(Converter<ResponseToken, Saml2ResponseValidatorResult> responseValidator) {
Assert.notNull(responseValidator, "responseValidator cannot be null");
this.responseValidator = responseValidator;
}
void setAssertionValidator(Converter<AssertionToken, Saml2ResponseValidatorResult> assertionValidator) {
Assert.notNull(assertionValidator, "assertionValidator cannot be null");
this.assertionValidator = assertionValidator;
}
void setAssertionElementsDecrypter(Consumer<AssertionToken> assertionDecrypter) {
Assert.notNull(assertionDecrypter, "assertionDecrypter cannot be null");
this.assertionElementsDecrypter = assertionDecrypter;
}
void setResponseAuthenticationConverter(
Converter<ResponseToken, ? extends AbstractAuthenticationToken> responseAuthenticationConverter) {
Assert.notNull(responseAuthenticationConverter, "responseAuthenticationConverter cannot be null");
this.responseAuthenticationConverter = responseAuthenticationConverter;
}
static Converter<ResponseToken, Saml2ResponseValidatorResult> createDefaultResponseValidator() {
return (responseToken) -> {
Response response = responseToken.getResponse();
Saml2AuthenticationToken token = responseToken.getToken();
Saml2ResponseValidatorResult result = Saml2ResponseValidatorResult.success();
List<String> statusCodes = getStatusCodes(response);
if (!isSuccess(statusCodes)) {
for (String statusCode : statusCodes) {
String message = String.format("Invalid status [%s] for SAML response [%s]", statusCode,
response.getID());
result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, message));
}
}
String inResponseTo = response.getInResponseTo();
result = result.concat(validateInResponseTo(token.getAuthenticationRequest(), inResponseTo));
String issuer = response.getIssuer().getValue();
String destination = response.getDestination();
String location = token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation();
if (StringUtils.hasText(destination) && !destination.equals(location)) {
String message = "Invalid destination [" + destination + "] for SAML response [" + response.getID()
+ "]";
result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_DESTINATION, message));
}
String assertingPartyEntityId = token.getRelyingPartyRegistration()
.getAssertingPartyMetadata()
.getEntityId();
if (!StringUtils.hasText(issuer) || !issuer.equals(assertingPartyEntityId)) {
String message = String.format("Invalid issuer [%s] for SAML response [%s]", issuer, response.getID());
result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_ISSUER, message));
}
if (response.getAssertions().isEmpty()) {
result = result.concat(
new Saml2Error(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA, "No assertions found in response."));
}
return result;
};
}
private static List<String> getStatusCodes(Response response) {
if (response.getStatus() == null) {
return List.of(StatusCode.SUCCESS);
}
if (response.getStatus().getStatusCode() == null) {
return List.of(StatusCode.SUCCESS);
}
StatusCode parentStatusCode = response.getStatus().getStatusCode();
String parentStatusCodeValue = parentStatusCode.getValue();
if (!includeChildStatusCodes.contains(parentStatusCodeValue)) {
return List.of(parentStatusCodeValue);
}
StatusCode childStatusCode = parentStatusCode.getStatusCode();
if (childStatusCode == null) {
return List.of(parentStatusCodeValue);
}
String childStatusCodeValue = childStatusCode.getValue();
if (childStatusCodeValue == null) {
return List.of(parentStatusCodeValue);
}
return List.of(parentStatusCodeValue, childStatusCodeValue);
}
private static boolean isSuccess(List<String> statusCodes) {
if (statusCodes.size() != 1) {
return false;
}
String statusCode = statusCodes.get(0);
return StatusCode.SUCCESS.equals(statusCode);
}
private static Saml2ResponseValidatorResult validateInResponseTo(AbstractSaml2AuthenticationRequest storedRequest,
String inResponseTo) {
if (!StringUtils.hasText(inResponseTo)) {
return Saml2ResponseValidatorResult.success();
}
if (storedRequest == null) {
String message = "The response contained an InResponseTo attribute [" + inResponseTo + "]"
+ " but no saved authentication request was found";
return Saml2ResponseValidatorResult
.failure(new Saml2Error(Saml2ErrorCodes.INVALID_IN_RESPONSE_TO, message));
}
if (!inResponseTo.equals(storedRequest.getId())) {
String message = "The InResponseTo attribute [" + inResponseTo + "] does not match the ID of the "
+ "authentication request [" + storedRequest.getId() + "]";
return Saml2ResponseValidatorResult
.failure(new Saml2Error(Saml2ErrorCodes.INVALID_IN_RESPONSE_TO, message));
}
return Saml2ResponseValidatorResult.success();
}
static Converter<AssertionToken, Saml2ResponseValidatorResult> createDefaultAssertionValidator() {
return createDefaultAssertionValidatorWithParameters(
(params) -> params.put(SAML2AssertionValidationParameters.CLOCK_SKEW, Duration.ofMinutes(5)));
}
static Converter<AssertionToken, Saml2ResponseValidatorResult> createDefaultAssertionValidator(
Converter<AssertionToken, ValidationContext> contextConverter) {
return createAssertionValidator(Saml2ErrorCodes.INVALID_ASSERTION,
(assertionToken) -> SAML20AssertionValidators.attributeValidator, contextConverter);
}
static Converter<AssertionToken, Saml2ResponseValidatorResult> createDefaultAssertionValidatorWithParameters(
Consumer<Map<String, Object>> validationContextParameters) {
return createAssertionValidator(Saml2ErrorCodes.INVALID_ASSERTION,
(assertionToken) -> SAML20AssertionValidators.attributeValidator,
(assertionToken) -> createValidationContext(assertionToken, validationContextParameters));
}
static Converter<ResponseToken, Saml2Authentication> createDefaultResponseAuthenticationConverter() {
return (responseToken) -> {
Response response = responseToken.response;
Saml2AuthenticationToken token = responseToken.token;
Assertion assertion = CollectionUtils.firstElement(response.getAssertions());
String username = assertion.getSubject().getNameID().getValue();
Map<String, List<Object>> attributes = getAssertionAttributes(assertion);
List<String> sessionIndexes = getSessionIndexes(assertion);
DefaultSaml2AuthenticatedPrincipal principal = new DefaultSaml2AuthenticatedPrincipal(username, attributes,
sessionIndexes);
String registrationId = responseToken.token.getRelyingPartyRegistration().getRegistrationId();
principal.setRelyingPartyRegistrationId(registrationId);
return new Saml2Authentication(principal, token.getSaml2Response(),
AuthorityUtils.createAuthorityList("ROLE_USER"));
};
}
@Override
public Authentication authenticate(Authentication authentication) throws AuthenticationException {
try {
Saml2AuthenticationToken token = (Saml2AuthenticationToken) authentication;
String serializedResponse = token.getSaml2Response();
Response response = parseResponse(serializedResponse);
process(token, response);
AbstractAuthenticationToken authenticationResponse = this.responseAuthenticationConverter
.convert(new ResponseToken(response, token));
if (authenticationResponse != null) {
authenticationResponse.setDetails(authentication.getDetails());
}
return authenticationResponse;
}
catch (Saml2AuthenticationException ex) {
throw ex;
}
catch (Exception ex) {
throw createAuthenticationException(Saml2ErrorCodes.INTERNAL_VALIDATION_ERROR, ex.getMessage(), ex);
}
}
@Override
public boolean supports(Class<?> authentication) {
return authentication != null && Saml2AuthenticationToken.class.isAssignableFrom(authentication);
}
private Response parseResponse(String response) throws Saml2Exception, Saml2AuthenticationException {
try {
return this.saml.deserialize(response);
}
catch (Exception ex) {
throw createAuthenticationException(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA, ex.getMessage(), ex);
}
}
private void process(Saml2AuthenticationToken token, Response response) {
String issuer = response.getIssuer().getValue();
this.logger.debug(LogMessage.format("Processing SAML response from %s", issuer));
boolean responseSigned = response.isSigned();
ResponseToken responseToken = new ResponseToken(response, token);
Saml2ResponseValidatorResult result = this.responseSignatureValidator.convert(responseToken);
if (responseSigned) {
this.responseElementsDecrypter.accept(responseToken);
}
else if (!response.getEncryptedAssertions().isEmpty()) {
result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE,
"Did not decrypt response [" + response.getID() + "] since it is not signed"));
}
result = result.concat(this.responseValidator.convert(responseToken));
boolean allAssertionsSigned = true;
for (Assertion assertion : response.getAssertions()) {
AssertionToken assertionToken = new AssertionToken(assertion, token);
result = result.concat(this.assertionSignatureValidator.convert(assertionToken));
allAssertionsSigned = allAssertionsSigned && assertion.isSigned();
if (responseSigned || assertion.isSigned()) {
this.assertionElementsDecrypter.accept(new AssertionToken(assertion, token));
}
result = result.concat(this.assertionValidator.convert(assertionToken));
}
if (!responseSigned && !allAssertionsSigned) {
String description = "Either the response or one of the assertions is unsigned. "
+ "Please either sign the response or all of the assertions.";
result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, description));
}
Assertion firstAssertion = CollectionUtils.firstElement(response.getAssertions());
if (firstAssertion != null && !hasName(firstAssertion)) {
Saml2Error error = new Saml2Error(Saml2ErrorCodes.SUBJECT_NOT_FOUND,
"Assertion [" + firstAssertion.getID() + "] is missing a subject");
result = result.concat(error);
}
if (result.hasErrors()) {
Collection<Saml2Error> errors = result.getErrors();
if (this.logger.isTraceEnabled()) {
this.logger.trace("Found " + errors.size() + " validation errors in SAML response [" + response.getID()
+ "]: " + errors);
}
else if (this.logger.isDebugEnabled()) {
this.logger
.debug("Found " + errors.size() + " validation errors in SAML response [" + response.getID() + "]");
}
Saml2Error first = errors.iterator().next();
throw createAuthenticationException(first.getErrorCode(), first.getDescription(), null);
}
else {
if (this.logger.isDebugEnabled()) {
this.logger.debug("Successfully processed SAML Response [" + response.getID() + "]");
}
}
}
private Converter<ResponseToken, Saml2ResponseValidatorResult> createDefaultResponseSignatureValidator() {
return (responseToken) -> {
Response response = responseToken.getResponse();
RelyingPartyRegistration registration = responseToken.getToken().getRelyingPartyRegistration();
if (response.isSigned()) {
AssertingPartyMetadata details = registration.getAssertingPartyMetadata();
Collection<Saml2X509Credential> credentials = details.getVerificationX509Credentials();
Collection<Saml2Error> errors = this.saml.withVerificationKeys(credentials)
.entityId(details.getEntityId())
.verify(response);
return Saml2ResponseValidatorResult.failure(errors);
}
return Saml2ResponseValidatorResult.success();
};
}
private Consumer<ResponseToken> createDefaultResponseElementsDecrypter() {
return (responseToken) -> {
Response response = responseToken.getResponse();
RelyingPartyRegistration registration = responseToken.getToken().getRelyingPartyRegistration();
try {
this.saml.withDecryptionKeys(registration.getDecryptionX509Credentials()).decrypt(response);
}
catch (Exception ex) {
throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
}
};
}
private Converter<AssertionToken, Saml2ResponseValidatorResult> createDefaultAssertionSignatureValidator() {
return (assertionToken) -> {
RelyingPartyRegistration registration = assertionToken.getToken().getRelyingPartyRegistration();
Assertion assertion = assertionToken.getAssertion();
if (assertion.isSigned()) {
AssertingPartyMetadata details = registration.getAssertingPartyMetadata();
Collection<Saml2X509Credential> credentials = details.getVerificationX509Credentials();
Collection<Saml2Error> errors = this.saml.withVerificationKeys(credentials)
.entityId(details.getEntityId())
.verify(assertion);
return Saml2ResponseValidatorResult.failure(errors);
}
return Saml2ResponseValidatorResult.success();
};
}
private Consumer<AssertionToken> createDefaultAssertionElementsDecrypter() {
return (assertionToken) -> {
Assertion assertion = assertionToken.getAssertion();
RelyingPartyRegistration registration = assertionToken.getToken().getRelyingPartyRegistration();
try {
this.saml.withDecryptionKeys(registration.getDecryptionX509Credentials()).decrypt(assertion);
}
catch (Exception ex) {
throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
}
};
}
private boolean hasName(Assertion assertion) {
if (assertion == null) {
return false;
}
if (assertion.getSubject() == null) {
return false;
}
if (assertion.getSubject().getNameID() == null) {
return false;
}
return assertion.getSubject().getNameID().getValue() != null;
}
private static Map<String, List<Object>> getAssertionAttributes(Assertion assertion) {
MultiValueMap<String, Object> attributeMap = new LinkedMultiValueMap<>();
for (AttributeStatement attributeStatement : assertion.getAttributeStatements()) {
for (Attribute attribute : attributeStatement.getAttributes()) {
List<Object> attributeValues = new ArrayList<>();
for (XMLObject xmlObject : attribute.getAttributeValues()) {
Object attributeValue = getXmlObjectValue(xmlObject);
if (attributeValue != null) {
attributeValues.add(attributeValue);
}
}
attributeMap.addAll(attribute.getName(), attributeValues);
}
}
return new LinkedHashMap<>(attributeMap); // gh-11785
}
private static List<String> getSessionIndexes(Assertion assertion) {
List<String> sessionIndexes = new ArrayList<>();
for (AuthnStatement statement : assertion.getAuthnStatements()) {
sessionIndexes.add(statement.getSessionIndex());
}
return sessionIndexes;
}
private static Object getXmlObjectValue(XMLObject xmlObject) {
if (xmlObject instanceof XSAny) {
return ((XSAny) xmlObject).getTextContent();
}
if (xmlObject instanceof XSString) {
return ((XSString) xmlObject).getValue();
}
if (xmlObject instanceof XSInteger) {
return ((XSInteger) xmlObject).getValue();
}
if (xmlObject instanceof XSURI) {
return ((XSURI) xmlObject).getURI();
}
if (xmlObject instanceof XSBoolean) {
XSBooleanValue xsBooleanValue = ((XSBoolean) xmlObject).getValue();
return (xsBooleanValue != null) ? xsBooleanValue.getValue() : null;
}
if (xmlObject instanceof XSDateTime) {
return ((XSDateTime) xmlObject).getValue();
}
return xmlObject;
}
private static Saml2AuthenticationException createAuthenticationException(String code, String message,
Exception cause) {
return new Saml2AuthenticationException(new Saml2Error(code, message), cause);
}
private static Converter<AssertionToken, Saml2ResponseValidatorResult> createAssertionValidator(String errorCode,
Converter<AssertionToken, SAML20AssertionValidator> validatorConverter,
Converter<AssertionToken, ValidationContext> contextConverter) {
return (assertionToken) -> {
Assertion assertion = assertionToken.assertion;
SAML20AssertionValidator validator = validatorConverter.convert(assertionToken);
ValidationContext context = contextConverter.convert(assertionToken);
try {
ValidationResult result = validator.validate(assertion, context);
if (result == ValidationResult.VALID) {
return Saml2ResponseValidatorResult.success();
}
}
catch (Exception ex) {
String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s", assertion.getID(),
((Response) assertion.getParent()).getID(), ex.getMessage());
return Saml2ResponseValidatorResult.failure(new Saml2Error(errorCode, message));
}
String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s", assertion.getID(),
((Response) assertion.getParent()).getID(), contextToString(context));
return Saml2ResponseValidatorResult.failure(new Saml2Error(errorCode, message));
};
}
private static String contextToString(ValidationContext context) {
StringBuilder sb = new StringBuilder();
for (Field field : context.getClass().getDeclaredFields()) {
ReflectionUtils.makeAccessible(field);
Object value = ReflectionUtils.getField(field, context);
sb.append(field.getName() + " = " + value + ",");
}
sb.deleteCharAt(sb.length() - 1);
return sb.toString();
}
private static ValidationContext createValidationContext(AssertionToken assertionToken,
Consumer<Map<String, Object>> paramsConsumer) {
Saml2AuthenticationToken token = assertionToken.token;
RelyingPartyRegistration relyingPartyRegistration = token.getRelyingPartyRegistration();
String audience = relyingPartyRegistration.getEntityId();
String recipient = relyingPartyRegistration.getAssertionConsumerServiceLocation();
String assertingPartyEntityId = relyingPartyRegistration.getAssertingPartyMetadata().getEntityId();
Map<String, Object> params = new HashMap<>();
Assertion assertion = assertionToken.getAssertion();
if (assertionContainsInResponseTo(assertion)) {
String requestId = getAuthnRequestId(token.getAuthenticationRequest());
params.put(SAML2AssertionValidationParameters.SC_VALID_IN_RESPONSE_TO, requestId);
}
params.put(SAML2AssertionValidationParameters.COND_VALID_AUDIENCES, Collections.singleton(audience));
params.put(SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS, Collections.singleton(recipient));
params.put(SAML2AssertionValidationParameters.VALID_ISSUERS, Collections.singleton(assertingPartyEntityId));
paramsConsumer.accept(params);
return new ValidationContext(params);
}
private static boolean assertionContainsInResponseTo(Assertion assertion) {
if (assertion.getSubject() == null) {
return false;
}
for (SubjectConfirmation confirmation : assertion.getSubject().getSubjectConfirmations()) {
SubjectConfirmationData confirmationData = confirmation.getSubjectConfirmationData();
if (confirmationData == null) {
continue;
}
if (StringUtils.hasText(confirmationData.getInResponseTo())) {
return true;
}
}
return false;
}
private static String getAuthnRequestId(AbstractSaml2AuthenticationRequest serialized) {
return (serialized != null) ? serialized.getId() : null;
}
private static class SAML20AssertionValidators {
private static final Collection<ConditionValidator> conditions = new ArrayList<>();
private static final Collection<SubjectConfirmationValidator> subjects = new ArrayList<>();
private static final Collection<StatementValidator> statements = new ArrayList<>();
static {
conditions.add(new AudienceRestrictionConditionValidator());
conditions.add(new DelegationRestrictionConditionValidator());
conditions.add(new ConditionValidator() {
@Nonnull
@Override
public QName getServicedCondition() {
return OneTimeUse.DEFAULT_ELEMENT_NAME;
}
@Nonnull
@Override
public ValidationResult validate(Condition condition, Assertion assertion, ValidationContext context) {
// applications should validate their own OneTimeUse conditions
return ValidationResult.VALID;
}
});
conditions.add(new ProxyRestrictionConditionValidator());
subjects.add(new BearerSubjectConfirmationValidator() {
@Nonnull
protected ValidationResult validateAddress(@Nonnull SubjectConfirmation confirmation,
@Nonnull Assertion assertion, @Nonnull ValidationContext context, boolean required)
throws AssertionValidationException {
return ValidationResult.VALID;
}
@Nonnull
protected ValidationResult validateAddress(@Nonnull SubjectConfirmationData confirmationData,
@Nonnull Assertion assertion, @Nonnull ValidationContext context, boolean required)
throws AssertionValidationException {
// applications should validate their own addresses - gh-7514
return ValidationResult.VALID;
}
});
}
private static final SAML20AssertionValidator attributeValidator = new SAML20AssertionValidator(conditions,
subjects, statements, null, null, null) {
@Nonnull
@Override
protected ValidationResult validateSignature(Assertion token, ValidationContext context) {
return ValidationResult.VALID;
}
};
}
/**
* A tuple containing an OpenSAML {@link Response} and its associated authentication
* token.
*
* @since 5.4
*/
static class ResponseToken {
private final Saml2AuthenticationToken token;
private final Response response;
ResponseToken(Response response, Saml2AuthenticationToken token) {
this.token = token;
this.response = response;
}
Response getResponse() {
return this.response;
}
Saml2AuthenticationToken getToken() {
return this.token;
}
}
/**
* A tuple containing an OpenSAML {@link Assertion} and its associated authentication
* token.
*
* @since 5.4
*/
static class AssertionToken {
private final Saml2AuthenticationToken token;
private final Assertion assertion;
AssertionToken(Assertion assertion, Saml2AuthenticationToken token) {
this.token = token;
this.assertion = assertion;
}
Assertion getAssertion() {
return this.assertion;
}
Saml2AuthenticationToken getToken() {
return this.token;
}
}
}

View File

@ -16,88 +16,24 @@
package org.springframework.security.saml2.provider.service.authentication;
import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Consumer;
import javax.annotation.Nonnull;
import javax.xml.namespace.QName;
import net.shibboleth.utilities.java.support.xml.ParserPool;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.opensaml.core.config.ConfigurationService;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.core.xml.config.XMLObjectProviderRegistry;
import org.opensaml.core.xml.schema.XSAny;
import org.opensaml.core.xml.schema.XSBoolean;
import org.opensaml.core.xml.schema.XSBooleanValue;
import org.opensaml.core.xml.schema.XSDateTime;
import org.opensaml.core.xml.schema.XSInteger;
import org.opensaml.core.xml.schema.XSString;
import org.opensaml.core.xml.schema.XSURI;
import org.opensaml.saml.common.assertion.AssertionValidationException;
import org.opensaml.saml.common.assertion.ValidationContext;
import org.opensaml.saml.common.assertion.ValidationResult;
import org.opensaml.saml.saml2.assertion.ConditionValidator;
import org.opensaml.saml.saml2.assertion.SAML20AssertionValidator;
import org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters;
import org.opensaml.saml.saml2.assertion.StatementValidator;
import org.opensaml.saml.saml2.assertion.SubjectConfirmationValidator;
import org.opensaml.saml.saml2.assertion.impl.AudienceRestrictionConditionValidator;
import org.opensaml.saml.saml2.assertion.impl.BearerSubjectConfirmationValidator;
import org.opensaml.saml.saml2.assertion.impl.DelegationRestrictionConditionValidator;
import org.opensaml.saml.saml2.assertion.impl.ProxyRestrictionConditionValidator;
import org.opensaml.saml.saml2.core.Assertion;
import org.opensaml.saml.saml2.core.Attribute;
import org.opensaml.saml.saml2.core.AttributeStatement;
import org.opensaml.saml.saml2.core.AuthnRequest;
import org.opensaml.saml.saml2.core.AuthnStatement;
import org.opensaml.saml.saml2.core.Condition;
import org.opensaml.saml.saml2.core.EncryptedAssertion;
import org.opensaml.saml.saml2.core.OneTimeUse;
import org.opensaml.saml.saml2.core.Response;
import org.opensaml.saml.saml2.core.StatusCode;
import org.opensaml.saml.saml2.core.SubjectConfirmation;
import org.opensaml.saml.saml2.core.SubjectConfirmationData;
import org.opensaml.saml.saml2.core.impl.AuthnRequestUnmarshaller;
import org.opensaml.saml.saml2.core.impl.ResponseUnmarshaller;
import org.opensaml.saml.saml2.encryption.Decrypter;
import org.opensaml.saml.security.impl.SAMLSignatureProfileValidator;
import org.opensaml.xmlsec.signature.support.SignaturePrevalidator;
import org.opensaml.xmlsec.signature.support.SignatureTrustEngine;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.springframework.core.convert.converter.Converter;
import org.springframework.core.log.LogMessage;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.core.OpenSamlInitializationService;
import org.springframework.security.saml2.core.Saml2Error;
import org.springframework.security.saml2.core.Saml2ErrorCodes;
import org.springframework.security.saml2.core.Saml2ResponseValidatorResult;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
/**
* Implementation of {@link AuthenticationProvider} for SAML authentications when
@ -142,48 +78,13 @@ import org.springframework.util.StringUtils;
*/
public final class OpenSaml4AuthenticationProvider implements AuthenticationProvider {
static {
OpenSamlInitializationService.initialize();
}
private final Log logger = LogFactory.getLog(this.getClass());
private final ResponseUnmarshaller responseUnmarshaller;
private static final AuthnRequestUnmarshaller authnRequestUnmarshaller;
static {
XMLObjectProviderRegistry registry = ConfigurationService.get(XMLObjectProviderRegistry.class);
authnRequestUnmarshaller = (AuthnRequestUnmarshaller) registry.getUnmarshallerFactory()
.getUnmarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME);
}
private final ParserPool parserPool;
private final Converter<ResponseToken, Saml2ResponseValidatorResult> responseSignatureValidator = createDefaultResponseSignatureValidator();
private Consumer<ResponseToken> responseElementsDecrypter = createDefaultResponseElementsDecrypter();
private Converter<ResponseToken, Saml2ResponseValidatorResult> responseValidator = createDefaultResponseValidator();
private final Converter<AssertionToken, Saml2ResponseValidatorResult> assertionSignatureValidator = createDefaultAssertionSignatureValidator();
private Consumer<AssertionToken> assertionElementsDecrypter = createDefaultAssertionElementsDecrypter();
private Converter<AssertionToken, Saml2ResponseValidatorResult> assertionValidator = createDefaultAssertionValidator();
private Converter<ResponseToken, ? extends AbstractAuthenticationToken> responseAuthenticationConverter = createDefaultResponseAuthenticationConverter();
private static final Set<String> includeChildStatusCodes = new HashSet<>(
Arrays.asList(StatusCode.REQUESTER, StatusCode.RESPONDER, StatusCode.VERSION_MISMATCH));
private final BaseOpenSamlAuthenticationProvider delegate;
/**
* Creates an {@link OpenSaml4AuthenticationProvider}
*/
public OpenSaml4AuthenticationProvider() {
XMLObjectProviderRegistry registry = ConfigurationService.get(XMLObjectProviderRegistry.class);
this.responseUnmarshaller = (ResponseUnmarshaller) registry.getUnmarshallerFactory()
.getUnmarshaller(Response.DEFAULT_ELEMENT_NAME);
this.parserPool = registry.getParserPool();
this.delegate = new BaseOpenSamlAuthenticationProvider(new OpenSaml4Template());
}
/**
@ -231,7 +132,8 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
*/
public void setResponseElementsDecrypter(Consumer<ResponseToken> responseElementsDecrypter) {
Assert.notNull(responseElementsDecrypter, "responseElementsDecrypter cannot be null");
this.responseElementsDecrypter = responseElementsDecrypter;
this.delegate
.setResponseElementsDecrypter((token) -> responseElementsDecrypter.accept(new ResponseToken(token)));
}
/**
@ -253,7 +155,7 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
*/
public void setResponseValidator(Converter<ResponseToken, Saml2ResponseValidatorResult> responseValidator) {
Assert.notNull(responseValidator, "responseValidator cannot be null");
this.responseValidator = responseValidator;
this.delegate.setResponseValidator((token) -> responseValidator.convert(new ResponseToken(token)));
}
/**
@ -297,7 +199,7 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
*/
public void setAssertionValidator(Converter<AssertionToken, Saml2ResponseValidatorResult> assertionValidator) {
Assert.notNull(assertionValidator, "assertionValidator cannot be null");
this.assertionValidator = assertionValidator;
this.delegate.setAssertionValidator((token) -> assertionValidator.convert(new AssertionToken(token)));
}
/**
@ -340,7 +242,7 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
*/
public void setAssertionElementsDecrypter(Consumer<AssertionToken> assertionDecrypter) {
Assert.notNull(assertionDecrypter, "assertionDecrypter cannot be null");
this.assertionElementsDecrypter = assertionDecrypter;
this.delegate.setAssertionElementsDecrypter((token) -> assertionDecrypter.accept(new AssertionToken(token)));
}
/**
@ -366,7 +268,8 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
public void setResponseAuthenticationConverter(
Converter<ResponseToken, ? extends AbstractAuthenticationToken> responseAuthenticationConverter) {
Assert.notNull(responseAuthenticationConverter, "responseAuthenticationConverter cannot be null");
this.responseAuthenticationConverter = responseAuthenticationConverter;
this.delegate.setResponseAuthenticationConverter(
(token) -> responseAuthenticationConverter.convert(new ResponseToken(token)));
}
/**
@ -375,95 +278,10 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
* @since 5.6
*/
public static Converter<ResponseToken, Saml2ResponseValidatorResult> createDefaultResponseValidator() {
return (responseToken) -> {
Response response = responseToken.getResponse();
Saml2AuthenticationToken token = responseToken.getToken();
Saml2ResponseValidatorResult result = Saml2ResponseValidatorResult.success();
List<String> statusCodes = getStatusCodes(response);
if (!isSuccess(statusCodes)) {
for (String statusCode : statusCodes) {
String message = String.format("Invalid status [%s] for SAML response [%s]", statusCode,
response.getID());
result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, message));
}
}
String inResponseTo = response.getInResponseTo();
result = result.concat(validateInResponseTo(token.getAuthenticationRequest(), inResponseTo));
String issuer = response.getIssuer().getValue();
String destination = response.getDestination();
String location = token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation();
if (StringUtils.hasText(destination) && !destination.equals(location)) {
String message = "Invalid destination [" + destination + "] for SAML response [" + response.getID()
+ "]";
result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_DESTINATION, message));
}
String assertingPartyEntityId = token.getRelyingPartyRegistration()
.getAssertingPartyMetadata()
.getEntityId();
if (!StringUtils.hasText(issuer) || !issuer.equals(assertingPartyEntityId)) {
String message = String.format("Invalid issuer [%s] for SAML response [%s]", issuer, response.getID());
result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_ISSUER, message));
}
if (response.getAssertions().isEmpty()) {
result = result.concat(
new Saml2Error(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA, "No assertions found in response."));
}
return result;
};
}
private static List<String> getStatusCodes(Response response) {
if (response.getStatus() == null) {
return List.of(StatusCode.SUCCESS);
}
if (response.getStatus().getStatusCode() == null) {
return List.of(StatusCode.SUCCESS);
}
StatusCode parentStatusCode = response.getStatus().getStatusCode();
String parentStatusCodeValue = parentStatusCode.getValue();
if (!includeChildStatusCodes.contains(parentStatusCodeValue)) {
return List.of(parentStatusCodeValue);
}
StatusCode childStatusCode = parentStatusCode.getStatusCode();
if (childStatusCode == null) {
return List.of(parentStatusCodeValue);
}
String childStatusCodeValue = childStatusCode.getValue();
if (childStatusCodeValue == null) {
return List.of(parentStatusCodeValue);
}
return List.of(parentStatusCodeValue, childStatusCodeValue);
}
private static boolean isSuccess(List<String> statusCodes) {
if (statusCodes.size() != 1) {
return false;
}
String statusCode = statusCodes.get(0);
return StatusCode.SUCCESS.equals(statusCode);
}
private static Saml2ResponseValidatorResult validateInResponseTo(AbstractSaml2AuthenticationRequest storedRequest,
String inResponseTo) {
if (!StringUtils.hasText(inResponseTo)) {
return Saml2ResponseValidatorResult.success();
}
if (storedRequest == null) {
String message = "The response contained an InResponseTo attribute [" + inResponseTo + "]"
+ " but no saved authentication request was found";
return Saml2ResponseValidatorResult
.failure(new Saml2Error(Saml2ErrorCodes.INVALID_IN_RESPONSE_TO, message));
}
if (!inResponseTo.equals(storedRequest.getId())) {
String message = "The InResponseTo attribute [" + inResponseTo + "] does not match the ID of the "
+ "authentication request [" + storedRequest.getId() + "]";
return Saml2ResponseValidatorResult
.failure(new Saml2Error(Saml2ErrorCodes.INVALID_IN_RESPONSE_TO, message));
}
return Saml2ResponseValidatorResult.success();
Converter<BaseOpenSamlAuthenticationProvider.ResponseToken, Saml2ResponseValidatorResult> delegate = BaseOpenSamlAuthenticationProvider
.createDefaultResponseValidator();
return (token) -> delegate
.convert(new BaseOpenSamlAuthenticationProvider.ResponseToken(token.getResponse(), token.getToken()));
}
/**
@ -472,7 +290,6 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
* @return the default assertion validator strategy
*/
public static Converter<AssertionToken, Saml2ResponseValidatorResult> createDefaultAssertionValidator() {
return createDefaultAssertionValidatorWithParameters(
(params) -> params.put(SAML2AssertionValidationParameters.CLOCK_SKEW, Duration.ofMinutes(5)));
}
@ -488,9 +305,12 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
@Deprecated
public static Converter<AssertionToken, Saml2ResponseValidatorResult> createDefaultAssertionValidator(
Converter<AssertionToken, ValidationContext> contextConverter) {
return createAssertionValidator(Saml2ErrorCodes.INVALID_ASSERTION,
(assertionToken) -> SAML20AssertionValidators.attributeValidator, contextConverter);
Converter<BaseOpenSamlAuthenticationProvider.AssertionToken, ValidationContext> contextDelegate = (
token) -> contextConverter.convert(new AssertionToken(token.getAssertion(), token.getToken()));
Converter<BaseOpenSamlAuthenticationProvider.AssertionToken, Saml2ResponseValidatorResult> delegate = BaseOpenSamlAuthenticationProvider
.createDefaultAssertionValidator(contextDelegate);
return (token) -> delegate
.convert(new BaseOpenSamlAuthenticationProvider.AssertionToken(token.getAssertion(), token.getToken()));
}
/**
@ -503,9 +323,10 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
*/
public static Converter<AssertionToken, Saml2ResponseValidatorResult> createDefaultAssertionValidatorWithParameters(
Consumer<Map<String, Object>> validationContextParameters) {
return createAssertionValidator(Saml2ErrorCodes.INVALID_ASSERTION,
(assertionToken) -> SAML20AssertionValidators.attributeValidator,
(assertionToken) -> createValidationContext(assertionToken, validationContextParameters));
Converter<BaseOpenSamlAuthenticationProvider.AssertionToken, Saml2ResponseValidatorResult> delegate = BaseOpenSamlAuthenticationProvider
.createDefaultAssertionValidatorWithParameters(validationContextParameters);
return (token) -> delegate
.convert(new BaseOpenSamlAuthenticationProvider.AssertionToken(token.getAssertion(), token.getToken()));
}
/**
@ -514,20 +335,10 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
* @return the default response authentication converter strategy
*/
public static Converter<ResponseToken, Saml2Authentication> createDefaultResponseAuthenticationConverter() {
return (responseToken) -> {
Response response = responseToken.response;
Saml2AuthenticationToken token = responseToken.token;
Assertion assertion = CollectionUtils.firstElement(response.getAssertions());
String username = assertion.getSubject().getNameID().getValue();
Map<String, List<Object>> attributes = getAssertionAttributes(assertion);
List<String> sessionIndexes = getSessionIndexes(assertion);
DefaultSaml2AuthenticatedPrincipal principal = new DefaultSaml2AuthenticatedPrincipal(username, attributes,
sessionIndexes);
String registrationId = responseToken.token.getRelyingPartyRegistration().getRegistrationId();
principal.setRelyingPartyRegistrationId(registrationId);
return new Saml2Authentication(principal, token.getSaml2Response(),
AuthorityUtils.createAuthorityList("ROLE_USER"));
};
Converter<BaseOpenSamlAuthenticationProvider.ResponseToken, Saml2Authentication> delegate = BaseOpenSamlAuthenticationProvider
.createDefaultResponseAuthenticationConverter();
return (token) -> delegate
.convert(new BaseOpenSamlAuthenticationProvider.ResponseToken(token.getResponse(), token.getToken()));
}
/**
@ -538,24 +349,7 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
*/
@Override
public Authentication authenticate(Authentication authentication) throws AuthenticationException {
try {
Saml2AuthenticationToken token = (Saml2AuthenticationToken) authentication;
String serializedResponse = token.getSaml2Response();
Response response = parseResponse(serializedResponse);
process(token, response);
AbstractAuthenticationToken authenticationResponse = this.responseAuthenticationConverter
.convert(new ResponseToken(response, token));
if (authenticationResponse != null) {
authenticationResponse.setDetails(authentication.getDetails());
}
return authenticationResponse;
}
catch (Saml2AuthenticationException ex) {
throw ex;
}
catch (Exception ex) {
throw createAuthenticationException(Saml2ErrorCodes.INTERNAL_VALIDATION_ERROR, ex.getMessage(), ex);
}
return this.delegate.authenticate(authentication);
}
@Override
@ -563,337 +357,6 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
return authentication != null && Saml2AuthenticationToken.class.isAssignableFrom(authentication);
}
private Response parseResponse(String response) throws Saml2Exception, Saml2AuthenticationException {
try {
Document document = this.parserPool
.parse(new ByteArrayInputStream(response.getBytes(StandardCharsets.UTF_8)));
Element element = document.getDocumentElement();
return (Response) this.responseUnmarshaller.unmarshall(element);
}
catch (Exception ex) {
throw createAuthenticationException(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA, ex.getMessage(), ex);
}
}
private void process(Saml2AuthenticationToken token, Response response) {
String issuer = response.getIssuer().getValue();
this.logger.debug(LogMessage.format("Processing SAML response from %s", issuer));
boolean responseSigned = response.isSigned();
ResponseToken responseToken = new ResponseToken(response, token);
Saml2ResponseValidatorResult result = this.responseSignatureValidator.convert(responseToken);
if (responseSigned) {
this.responseElementsDecrypter.accept(responseToken);
}
else if (!response.getEncryptedAssertions().isEmpty()) {
result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE,
"Did not decrypt response [" + response.getID() + "] since it is not signed"));
}
result = result.concat(this.responseValidator.convert(responseToken));
boolean allAssertionsSigned = true;
for (Assertion assertion : response.getAssertions()) {
AssertionToken assertionToken = new AssertionToken(assertion, token);
result = result.concat(this.assertionSignatureValidator.convert(assertionToken));
allAssertionsSigned = allAssertionsSigned && assertion.isSigned();
if (responseSigned || assertion.isSigned()) {
this.assertionElementsDecrypter.accept(new AssertionToken(assertion, token));
}
result = result.concat(this.assertionValidator.convert(assertionToken));
}
if (!responseSigned && !allAssertionsSigned) {
String description = "Either the response or one of the assertions is unsigned. "
+ "Please either sign the response or all of the assertions.";
result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, description));
}
Assertion firstAssertion = CollectionUtils.firstElement(response.getAssertions());
if (firstAssertion != null && !hasName(firstAssertion)) {
Saml2Error error = new Saml2Error(Saml2ErrorCodes.SUBJECT_NOT_FOUND,
"Assertion [" + firstAssertion.getID() + "] is missing a subject");
result = result.concat(error);
}
if (result.hasErrors()) {
Collection<Saml2Error> errors = result.getErrors();
if (this.logger.isTraceEnabled()) {
this.logger.trace("Found " + errors.size() + " validation errors in SAML response [" + response.getID()
+ "]: " + errors);
}
else if (this.logger.isDebugEnabled()) {
this.logger
.debug("Found " + errors.size() + " validation errors in SAML response [" + response.getID() + "]");
}
Saml2Error first = errors.iterator().next();
throw createAuthenticationException(first.getErrorCode(), first.getDescription(), null);
}
else {
if (this.logger.isDebugEnabled()) {
this.logger.debug("Successfully processed SAML Response [" + response.getID() + "]");
}
}
}
private Converter<ResponseToken, Saml2ResponseValidatorResult> createDefaultResponseSignatureValidator() {
return (responseToken) -> {
Response response = responseToken.getResponse();
RelyingPartyRegistration registration = responseToken.getToken().getRelyingPartyRegistration();
if (response.isSigned()) {
return OpenSamlVerificationUtils.verifySignature(response, registration).post(response.getSignature());
}
return Saml2ResponseValidatorResult.success();
};
}
private Consumer<ResponseToken> createDefaultResponseElementsDecrypter() {
return (responseToken) -> {
Response response = responseToken.getResponse();
RelyingPartyRegistration registration = responseToken.getToken().getRelyingPartyRegistration();
try {
OpenSamlDecryptionUtils.decryptResponseElements(response, registration);
}
catch (Exception ex) {
throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
}
};
}
private Converter<AssertionToken, Saml2ResponseValidatorResult> createDefaultAssertionSignatureValidator() {
return createAssertionValidator(Saml2ErrorCodes.INVALID_SIGNATURE, (assertionToken) -> {
RelyingPartyRegistration registration = assertionToken.getToken().getRelyingPartyRegistration();
SignatureTrustEngine engine = OpenSamlVerificationUtils.trustEngine(registration);
return SAML20AssertionValidators.createSignatureValidator(engine);
}, (assertionToken) -> new ValidationContext(
Collections.singletonMap(SAML2AssertionValidationParameters.SIGNATURE_REQUIRED, false)));
}
private Consumer<AssertionToken> createDefaultAssertionElementsDecrypter() {
return (assertionToken) -> {
Assertion assertion = assertionToken.getAssertion();
RelyingPartyRegistration registration = assertionToken.getToken().getRelyingPartyRegistration();
try {
OpenSamlDecryptionUtils.decryptAssertionElements(assertion, registration);
}
catch (Exception ex) {
throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
}
};
}
private boolean hasName(Assertion assertion) {
if (assertion == null) {
return false;
}
if (assertion.getSubject() == null) {
return false;
}
if (assertion.getSubject().getNameID() == null) {
return false;
}
return assertion.getSubject().getNameID().getValue() != null;
}
private static Map<String, List<Object>> getAssertionAttributes(Assertion assertion) {
MultiValueMap<String, Object> attributeMap = new LinkedMultiValueMap<>();
for (AttributeStatement attributeStatement : assertion.getAttributeStatements()) {
for (Attribute attribute : attributeStatement.getAttributes()) {
List<Object> attributeValues = new ArrayList<>();
for (XMLObject xmlObject : attribute.getAttributeValues()) {
Object attributeValue = getXmlObjectValue(xmlObject);
if (attributeValue != null) {
attributeValues.add(attributeValue);
}
}
attributeMap.addAll(attribute.getName(), attributeValues);
}
}
return new LinkedHashMap<>(attributeMap); // gh-11785
}
private static List<String> getSessionIndexes(Assertion assertion) {
List<String> sessionIndexes = new ArrayList<>();
for (AuthnStatement statement : assertion.getAuthnStatements()) {
sessionIndexes.add(statement.getSessionIndex());
}
return sessionIndexes;
}
private static Object getXmlObjectValue(XMLObject xmlObject) {
if (xmlObject instanceof XSAny) {
return ((XSAny) xmlObject).getTextContent();
}
if (xmlObject instanceof XSString) {
return ((XSString) xmlObject).getValue();
}
if (xmlObject instanceof XSInteger) {
return ((XSInteger) xmlObject).getValue();
}
if (xmlObject instanceof XSURI) {
return ((XSURI) xmlObject).getURI();
}
if (xmlObject instanceof XSBoolean) {
XSBooleanValue xsBooleanValue = ((XSBoolean) xmlObject).getValue();
return (xsBooleanValue != null) ? xsBooleanValue.getValue() : null;
}
if (xmlObject instanceof XSDateTime) {
return ((XSDateTime) xmlObject).getValue();
}
return xmlObject;
}
private static Saml2AuthenticationException createAuthenticationException(String code, String message,
Exception cause) {
return new Saml2AuthenticationException(new Saml2Error(code, message), cause);
}
private static Converter<AssertionToken, Saml2ResponseValidatorResult> createAssertionValidator(String errorCode,
Converter<AssertionToken, SAML20AssertionValidator> validatorConverter,
Converter<AssertionToken, ValidationContext> contextConverter) {
return (assertionToken) -> {
Assertion assertion = assertionToken.assertion;
SAML20AssertionValidator validator = validatorConverter.convert(assertionToken);
ValidationContext context = contextConverter.convert(assertionToken);
try {
ValidationResult result = validator.validate(assertion, context);
if (result == ValidationResult.VALID) {
return Saml2ResponseValidatorResult.success();
}
}
catch (Exception ex) {
String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s", assertion.getID(),
((Response) assertion.getParent()).getID(), ex.getMessage());
return Saml2ResponseValidatorResult.failure(new Saml2Error(errorCode, message));
}
String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s", assertion.getID(),
((Response) assertion.getParent()).getID(), context.getValidationFailureMessage());
return Saml2ResponseValidatorResult.failure(new Saml2Error(errorCode, message));
};
}
private static ValidationContext createValidationContext(AssertionToken assertionToken,
Consumer<Map<String, Object>> paramsConsumer) {
Saml2AuthenticationToken token = assertionToken.token;
RelyingPartyRegistration relyingPartyRegistration = token.getRelyingPartyRegistration();
String audience = relyingPartyRegistration.getEntityId();
String recipient = relyingPartyRegistration.getAssertionConsumerServiceLocation();
String assertingPartyEntityId = relyingPartyRegistration.getAssertingPartyMetadata().getEntityId();
Map<String, Object> params = new HashMap<>();
Assertion assertion = assertionToken.getAssertion();
if (assertionContainsInResponseTo(assertion)) {
String requestId = getAuthnRequestId(token.getAuthenticationRequest());
params.put(SAML2AssertionValidationParameters.SC_VALID_IN_RESPONSE_TO, requestId);
}
params.put(SAML2AssertionValidationParameters.COND_VALID_AUDIENCES, Collections.singleton(audience));
params.put(SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS, Collections.singleton(recipient));
params.put(SAML2AssertionValidationParameters.VALID_ISSUERS, Collections.singleton(assertingPartyEntityId));
paramsConsumer.accept(params);
return new ValidationContext(params);
}
private static boolean assertionContainsInResponseTo(Assertion assertion) {
if (assertion.getSubject() == null) {
return false;
}
for (SubjectConfirmation confirmation : assertion.getSubject().getSubjectConfirmations()) {
SubjectConfirmationData confirmationData = confirmation.getSubjectConfirmationData();
if (confirmationData == null) {
continue;
}
if (StringUtils.hasText(confirmationData.getInResponseTo())) {
return true;
}
}
return false;
}
private static String getAuthnRequestId(AbstractSaml2AuthenticationRequest serialized) {
return (serialized != null) ? serialized.getId() : null;
}
private static class SAML20AssertionValidators {
private static final Collection<ConditionValidator> conditions = new ArrayList<>();
private static final Collection<SubjectConfirmationValidator> subjects = new ArrayList<>();
private static final Collection<StatementValidator> statements = new ArrayList<>();
private static final SignaturePrevalidator validator = new SAMLSignatureProfileValidator();
static {
conditions.add(new AudienceRestrictionConditionValidator());
conditions.add(new DelegationRestrictionConditionValidator());
conditions.add(new ConditionValidator() {
@Nonnull
@Override
public QName getServicedCondition() {
return OneTimeUse.DEFAULT_ELEMENT_NAME;
}
@Nonnull
@Override
public ValidationResult validate(Condition condition, Assertion assertion, ValidationContext context) {
// applications should validate their own OneTimeUse conditions
return ValidationResult.VALID;
}
});
conditions.add(new ProxyRestrictionConditionValidator());
subjects.add(new BearerSubjectConfirmationValidator() {
@Override
protected ValidationResult validateAddress(SubjectConfirmation confirmation, Assertion assertion,
ValidationContext context, boolean required) {
// applications should validate their own addresses - gh-7514
return ValidationResult.VALID;
}
});
}
private static final SAML20AssertionValidator attributeValidator = new SAML20AssertionValidator(conditions,
subjects, statements, null, null, null) {
@Nonnull
@Override
protected ValidationResult validateSignature(Assertion token, ValidationContext context) {
return ValidationResult.VALID;
}
};
static SAML20AssertionValidator createSignatureValidator(SignatureTrustEngine engine) {
return new SAML20AssertionValidator(new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), null, engine,
validator) {
@Nonnull
@Override
protected ValidationResult validateBasicData(@Nonnull Assertion assertion,
@Nonnull ValidationContext context) throws AssertionValidationException {
return ValidationResult.VALID;
}
@Nonnull
@Override
protected ValidationResult validateConditions(Assertion assertion, ValidationContext context) {
return ValidationResult.VALID;
}
@Nonnull
@Override
protected ValidationResult validateSubjectConfirmation(Assertion assertion, ValidationContext context) {
return ValidationResult.VALID;
}
@Nonnull
@Override
protected ValidationResult validateStatements(Assertion assertion, ValidationContext context) {
return ValidationResult.VALID;
}
@Override
protected ValidationResult validateIssuer(Assertion assertion, ValidationContext context) {
return ValidationResult.VALID;
}
};
}
}
/**
* A tuple containing an OpenSAML {@link Response} and its associated authentication
* token.
@ -911,6 +374,11 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
this.response = response;
}
ResponseToken(BaseOpenSamlAuthenticationProvider.ResponseToken token) {
this.token = token.getToken();
this.response = token.getResponse();
}
public Response getResponse() {
return this.response;
}
@ -938,6 +406,11 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
this.assertion = assertion;
}
AssertionToken(BaseOpenSamlAuthenticationProvider.AssertionToken token) {
this.token = token.getToken();
this.assertion = token.getAssertion();
}
public Assertion getAssertion() {
return this.assertion;
}

View File

@ -0,0 +1,617 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.saml2.provider.service.authentication;
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.xml.namespace.QName;
import net.shibboleth.utilities.java.support.resolver.CriteriaSet;
import net.shibboleth.utilities.java.support.xml.SerializeSupport;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.opensaml.core.criterion.EntityIdCriterion;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.core.xml.XMLObjectBuilder;
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
import org.opensaml.core.xml.io.Marshaller;
import org.opensaml.core.xml.io.MarshallingException;
import org.opensaml.core.xml.io.Unmarshaller;
import org.opensaml.core.xml.io.UnmarshallerFactory;
import org.opensaml.core.xml.util.XMLObjectSupport;
import org.opensaml.saml.common.xml.SAMLConstants;
import org.opensaml.saml.criterion.ProtocolCriterion;
import org.opensaml.saml.ext.saml2delrestrict.Delegate;
import org.opensaml.saml.ext.saml2delrestrict.DelegationRestrictionType;
import org.opensaml.saml.metadata.criteria.role.impl.EvaluableProtocolRoleDescriptorCriterion;
import org.opensaml.saml.saml2.core.Assertion;
import org.opensaml.saml.saml2.core.Attribute;
import org.opensaml.saml.saml2.core.AttributeStatement;
import org.opensaml.saml.saml2.core.Condition;
import org.opensaml.saml.saml2.core.EncryptedAssertion;
import org.opensaml.saml.saml2.core.EncryptedAttribute;
import org.opensaml.saml.saml2.core.Issuer;
import org.opensaml.saml.saml2.core.LogoutRequest;
import org.opensaml.saml.saml2.core.NameID;
import org.opensaml.saml.saml2.core.RequestAbstractType;
import org.opensaml.saml.saml2.core.Response;
import org.opensaml.saml.saml2.core.StatusResponseType;
import org.opensaml.saml.saml2.core.Subject;
import org.opensaml.saml.saml2.core.SubjectConfirmation;
import org.opensaml.saml.saml2.encryption.Decrypter;
import org.opensaml.saml.saml2.encryption.EncryptedElementTypeEncryptedKeyResolver;
import org.opensaml.saml.security.impl.SAMLMetadataSignatureSigningParametersResolver;
import org.opensaml.saml.security.impl.SAMLSignatureProfileValidator;
import org.opensaml.security.SecurityException;
import org.opensaml.security.credential.BasicCredential;
import org.opensaml.security.credential.Credential;
import org.opensaml.security.credential.CredentialResolver;
import org.opensaml.security.credential.CredentialSupport;
import org.opensaml.security.credential.UsageType;
import org.opensaml.security.credential.criteria.impl.EvaluableEntityIDCredentialCriterion;
import org.opensaml.security.credential.criteria.impl.EvaluableUsageCredentialCriterion;
import org.opensaml.security.credential.impl.CollectionCredentialResolver;
import org.opensaml.security.criteria.UsageCriterion;
import org.opensaml.security.x509.BasicX509Credential;
import org.opensaml.xmlsec.SignatureSigningParameters;
import org.opensaml.xmlsec.SignatureSigningParametersResolver;
import org.opensaml.xmlsec.config.impl.DefaultSecurityConfigurationBootstrap;
import org.opensaml.xmlsec.criterion.SignatureSigningConfigurationCriterion;
import org.opensaml.xmlsec.crypto.XMLSigningUtil;
import org.opensaml.xmlsec.encryption.support.ChainingEncryptedKeyResolver;
import org.opensaml.xmlsec.encryption.support.DecryptionException;
import org.opensaml.xmlsec.encryption.support.EncryptedKeyResolver;
import org.opensaml.xmlsec.encryption.support.InlineEncryptedKeyResolver;
import org.opensaml.xmlsec.encryption.support.SimpleRetrievalMethodEncryptedKeyResolver;
import org.opensaml.xmlsec.impl.BasicSignatureSigningConfiguration;
import org.opensaml.xmlsec.keyinfo.KeyInfoCredentialResolver;
import org.opensaml.xmlsec.keyinfo.KeyInfoGeneratorManager;
import org.opensaml.xmlsec.keyinfo.NamedKeyInfoGeneratorManager;
import org.opensaml.xmlsec.keyinfo.impl.CollectionKeyInfoCredentialResolver;
import org.opensaml.xmlsec.keyinfo.impl.X509KeyInfoGeneratorFactory;
import org.opensaml.xmlsec.signature.SignableXMLObject;
import org.opensaml.xmlsec.signature.Signature;
import org.opensaml.xmlsec.signature.support.SignatureConstants;
import org.opensaml.xmlsec.signature.support.SignatureSupport;
import org.opensaml.xmlsec.signature.support.SignatureTrustEngine;
import org.opensaml.xmlsec.signature.support.impl.ExplicitKeySignatureTrustEngine;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.core.Saml2Error;
import org.springframework.security.saml2.core.Saml2ErrorCodes;
import org.springframework.security.saml2.core.Saml2ParameterNames;
import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.util.Assert;
import org.springframework.web.util.UriComponentsBuilder;
import org.springframework.web.util.UriUtils;
/**
* For internal use only. Subject to breaking changes at any time.
*/
final class OpenSaml4Template implements OpenSamlOperations {
private static final Log logger = LogFactory.getLog(OpenSaml4Template.class);
@Override
public <T extends XMLObject> T build(QName elementName) {
XMLObjectBuilder<?> builder = XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(elementName);
if (builder == null) {
throw new Saml2Exception("Unable to resolve Builder for " + elementName);
}
return (T) builder.buildObject(elementName);
}
@Override
public <T extends XMLObject> T deserialize(String serialized) {
return deserialize(new ByteArrayInputStream(serialized.getBytes(StandardCharsets.UTF_8)));
}
@Override
public <T extends XMLObject> T deserialize(InputStream serialized) {
try {
Document document = XMLObjectProviderRegistrySupport.getParserPool().parse(serialized);
Element element = document.getDocumentElement();
UnmarshallerFactory factory = XMLObjectProviderRegistrySupport.getUnmarshallerFactory();
Unmarshaller unmarshaller = factory.getUnmarshaller(element);
if (unmarshaller == null) {
throw new Saml2Exception("Unsupported element of type " + element.getTagName());
}
return (T) unmarshaller.unmarshall(element);
}
catch (Saml2Exception ex) {
throw ex;
}
catch (Exception ex) {
throw new Saml2Exception("Failed to deserialize payload", ex);
}
}
@Override
public OpenSaml4SerializationConfigurer serialize(XMLObject object) {
Marshaller marshaller = XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(object);
try {
return serialize(marshaller.marshall(object));
}
catch (MarshallingException ex) {
throw new Saml2Exception(ex);
}
}
@Override
public OpenSaml4SerializationConfigurer serialize(Element element) {
return new OpenSaml4SerializationConfigurer(element);
}
@Override
public OpenSaml4SignatureConfigurer withSigningKeys(Collection<Saml2X509Credential> credentials) {
return new OpenSaml4SignatureConfigurer(credentials);
}
@Override
public OpenSaml4VerificationConfigurer withVerificationKeys(Collection<Saml2X509Credential> credentials) {
return new OpenSaml4VerificationConfigurer(credentials);
}
@Override
public OpenSaml4DecryptionConfigurer withDecryptionKeys(Collection<Saml2X509Credential> credentials) {
return new OpenSaml4DecryptionConfigurer(credentials);
}
OpenSaml4Template() {
}
static final class OpenSaml4SerializationConfigurer
implements SerializationConfigurer<OpenSaml4SerializationConfigurer> {
private final Element element;
boolean pretty;
OpenSaml4SerializationConfigurer(Element element) {
this.element = element;
}
@Override
public OpenSaml4SerializationConfigurer prettyPrint(boolean pretty) {
this.pretty = pretty;
return this;
}
@Override
public String serialize() {
if (this.pretty) {
return SerializeSupport.prettyPrintXML(this.element);
}
return SerializeSupport.nodeToString(this.element);
}
}
static final class OpenSaml4SignatureConfigurer implements SignatureConfigurer<OpenSaml4SignatureConfigurer> {
private final Collection<Saml2X509Credential> credentials;
private final Map<String, String> components = new LinkedHashMap<>();
private List<String> algs = List.of(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256);
OpenSaml4SignatureConfigurer(Collection<Saml2X509Credential> credentials) {
this.credentials = credentials;
}
@Override
public OpenSaml4SignatureConfigurer algorithms(List<String> algs) {
this.algs = algs;
return this;
}
@Override
public <O extends SignableXMLObject> O sign(O object) {
SignatureSigningParameters parameters = resolveSigningParameters();
try {
SignatureSupport.signObject(object, parameters);
}
catch (Exception ex) {
throw new Saml2Exception(ex);
}
return object;
}
@Override
public Map<String, String> sign(Map<String, String> params) {
SignatureSigningParameters parameters = resolveSigningParameters();
this.components.putAll(params);
Credential credential = parameters.getSigningCredential();
String algorithmUri = parameters.getSignatureAlgorithm();
this.components.put(Saml2ParameterNames.SIG_ALG, algorithmUri);
UriComponentsBuilder builder = UriComponentsBuilder.newInstance();
for (Map.Entry<String, String> component : this.components.entrySet()) {
builder.queryParam(component.getKey(),
UriUtils.encode(component.getValue(), StandardCharsets.ISO_8859_1));
}
String queryString = builder.build(true).toString().substring(1);
try {
byte[] rawSignature = XMLSigningUtil.signWithURI(credential, algorithmUri,
queryString.getBytes(StandardCharsets.UTF_8));
String b64Signature = Saml2Utils.samlEncode(rawSignature);
this.components.put(Saml2ParameterNames.SIGNATURE, b64Signature);
}
catch (SecurityException ex) {
throw new Saml2Exception(ex);
}
return this.components;
}
private SignatureSigningParameters resolveSigningParameters() {
List<Credential> credentials = resolveSigningCredentials();
List<String> digests = Collections.singletonList(SignatureConstants.ALGO_ID_DIGEST_SHA256);
String canonicalization = SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS;
SignatureSigningParametersResolver resolver = new SAMLMetadataSignatureSigningParametersResolver();
BasicSignatureSigningConfiguration signingConfiguration = new BasicSignatureSigningConfiguration();
signingConfiguration.setSigningCredentials(credentials);
signingConfiguration.setSignatureAlgorithms(this.algs);
signingConfiguration.setSignatureReferenceDigestMethods(digests);
signingConfiguration.setSignatureCanonicalizationAlgorithm(canonicalization);
signingConfiguration.setKeyInfoGeneratorManager(buildSignatureKeyInfoGeneratorManager());
CriteriaSet criteria = new CriteriaSet(new SignatureSigningConfigurationCriterion(signingConfiguration));
try {
SignatureSigningParameters parameters = resolver.resolveSingle(criteria);
Assert.notNull(parameters, "Failed to resolve any signing credential");
return parameters;
}
catch (Exception ex) {
throw new Saml2Exception(ex);
}
}
private NamedKeyInfoGeneratorManager buildSignatureKeyInfoGeneratorManager() {
final NamedKeyInfoGeneratorManager namedManager = new NamedKeyInfoGeneratorManager();
namedManager.setUseDefaultManager(true);
final KeyInfoGeneratorManager defaultManager = namedManager.getDefaultManager();
// Generator for X509Credentials
final X509KeyInfoGeneratorFactory x509Factory = new X509KeyInfoGeneratorFactory();
x509Factory.setEmitEntityCertificate(true);
x509Factory.setEmitEntityCertificateChain(true);
defaultManager.registerFactory(x509Factory);
return namedManager;
}
private List<Credential> resolveSigningCredentials() {
List<Credential> credentials = new ArrayList<>();
for (Saml2X509Credential x509Credential : this.credentials) {
X509Certificate certificate = x509Credential.getCertificate();
PrivateKey privateKey = x509Credential.getPrivateKey();
BasicCredential credential = CredentialSupport.getSimpleCredential(certificate, privateKey);
credential.setUsageType(UsageType.SIGNING);
credentials.add(credential);
}
return credentials;
}
}
static final class OpenSaml4VerificationConfigurer implements VerificationConfigurer {
private final Collection<Saml2X509Credential> credentials;
private String entityId;
OpenSaml4VerificationConfigurer(Collection<Saml2X509Credential> credentials) {
this.credentials = credentials;
}
@Override
public VerificationConfigurer entityId(String entityId) {
this.entityId = entityId;
return this;
}
private SignatureTrustEngine trustEngine(Collection<Saml2X509Credential> keys) {
Set<Credential> credentials = new HashSet<>();
for (Saml2X509Credential key : keys) {
BasicX509Credential cred = new BasicX509Credential(key.getCertificate());
cred.setUsageType(UsageType.SIGNING);
cred.setEntityId(this.entityId);
credentials.add(cred);
}
CredentialResolver credentialsResolver = new CollectionCredentialResolver(credentials);
return new ExplicitKeySignatureTrustEngine(credentialsResolver,
DefaultSecurityConfigurationBootstrap.buildBasicInlineKeyInfoCredentialResolver());
}
private CriteriaSet verificationCriteria(Issuer issuer) {
return new CriteriaSet(new EvaluableEntityIDCredentialCriterion(new EntityIdCriterion(issuer.getValue())),
new EvaluableProtocolRoleDescriptorCriterion(new ProtocolCriterion(SAMLConstants.SAML20P_NS)),
new EvaluableUsageCredentialCriterion(new UsageCriterion(UsageType.SIGNING)));
}
@Override
public Collection<Saml2Error> verify(SignableXMLObject signable) {
if (signable instanceof StatusResponseType response) {
return verifySignature(response.getID(), response.getIssuer(), response.getSignature());
}
if (signable instanceof RequestAbstractType request) {
return verifySignature(request.getID(), request.getIssuer(), request.getSignature());
}
if (signable instanceof Assertion assertion) {
return verifySignature(assertion.getID(), assertion.getIssuer(), assertion.getSignature());
}
throw new Saml2Exception("Unsupported object of type: " + signable.getClass().getName());
}
private Collection<Saml2Error> verifySignature(String id, Issuer issuer, Signature signature) {
SignatureTrustEngine trustEngine = trustEngine(this.credentials);
CriteriaSet criteria = verificationCriteria(issuer);
Collection<Saml2Error> errors = new ArrayList<>();
SAMLSignatureProfileValidator profileValidator = new SAMLSignatureProfileValidator();
try {
profileValidator.validate(signature);
}
catch (Exception ex) {
errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE,
"Invalid signature for object [" + id + "]: "));
}
try {
if (!trustEngine.validate(signature, criteria)) {
errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE,
"Invalid signature for object [" + id + "]"));
}
}
catch (Exception ex) {
errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE,
"Invalid signature for object [" + id + "]: "));
}
return errors;
}
@Override
public Collection<Saml2Error> verify(RedirectParameters parameters) {
SignatureTrustEngine trustEngine = trustEngine(this.credentials);
CriteriaSet criteria = verificationCriteria(parameters.getIssuer());
if (parameters.getAlgorithm() == null) {
return Collections.singletonList(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE,
"Missing signature algorithm for object [" + parameters.getId() + "]"));
}
if (!parameters.hasSignature()) {
return Collections.singletonList(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE,
"Missing signature for object [" + parameters.getId() + "]"));
}
Collection<Saml2Error> errors = new ArrayList<>();
String algorithmUri = parameters.getAlgorithm();
try {
if (!trustEngine.validate(parameters.getSignature(), parameters.getContent(), algorithmUri, criteria,
null)) {
errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE,
"Invalid signature for object [" + parameters.getId() + "]"));
}
}
catch (Exception ex) {
errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE,
"Invalid signature for object [" + parameters.getId() + "]: "));
}
return errors;
}
}
static final class OpenSaml4DecryptionConfigurer implements DecryptionConfigurer {
private static final EncryptedKeyResolver encryptedKeyResolver = new ChainingEncryptedKeyResolver(
Arrays.asList(new InlineEncryptedKeyResolver(), new EncryptedElementTypeEncryptedKeyResolver(),
new SimpleRetrievalMethodEncryptedKeyResolver()));
private final Decrypter decrypter;
OpenSaml4DecryptionConfigurer(Collection<Saml2X509Credential> decryptionCredentials) {
this.decrypter = decrypter(decryptionCredentials);
}
private static Decrypter decrypter(Collection<Saml2X509Credential> decryptionCredentials) {
Collection<Credential> credentials = new ArrayList<>();
for (Saml2X509Credential key : decryptionCredentials) {
Credential cred = CredentialSupport.getSimpleCredential(key.getCertificate(), key.getPrivateKey());
credentials.add(cred);
}
KeyInfoCredentialResolver resolver = new CollectionKeyInfoCredentialResolver(credentials);
Decrypter decrypter = new Decrypter(null, resolver, encryptedKeyResolver);
decrypter.setRootInNewDocument(true);
return decrypter;
}
@Override
public void decrypt(XMLObject object) {
if (object instanceof Response response) {
decryptResponse(response);
return;
}
if (object instanceof Assertion assertion) {
decryptAssertion(assertion);
}
if (object instanceof LogoutRequest request) {
decryptLogoutRequest(request);
}
}
/*
* The methods that follow are adapted from OpenSAML's {@link DecryptAssertions},
* {@link DecryptNameIDs}, and {@link DecryptAttributes}.
*
* <p>The reason that these OpenSAML classes are not used directly is because they
* reference {@link javax.servlet.http.HttpServletRequest} which is a lower
* Servlet API version than what Spring Security SAML uses.
*
* If OpenSAML 5 updates to {@link jakarta.servlet.http.HttpServletRequest}, then
* this arrangement can be revisited.
*/
private void decryptResponse(Response response) {
Collection<Assertion> decrypteds = new ArrayList<>();
Collection<EncryptedAssertion> encrypteds = new ArrayList<>();
int count = 0;
int size = response.getEncryptedAssertions().size();
for (EncryptedAssertion encrypted : response.getEncryptedAssertions()) {
logger.trace(String.format("Decrypting EncryptedAssertion (%d/%d) in Response [%s]", count, size,
response.getID()));
try {
Assertion decrypted = this.decrypter.decrypt(encrypted);
if (decrypted != null) {
encrypteds.add(encrypted);
decrypteds.add(decrypted);
}
count++;
}
catch (DecryptionException ex) {
throw new Saml2Exception(ex);
}
}
response.getEncryptedAssertions().removeAll(encrypteds);
response.getAssertions().addAll(decrypteds);
// Re-marshall the response so that any ID attributes within the decrypted
// Assertions
// will have their ID-ness re-established at the DOM level.
if (!decrypteds.isEmpty()) {
try {
XMLObjectSupport.marshall(response);
}
catch (final MarshallingException ex) {
throw new Saml2Exception(ex);
}
}
}
private void decryptAssertion(Assertion assertion) {
for (AttributeStatement statement : assertion.getAttributeStatements()) {
decryptAttributes(statement);
}
decryptSubject(assertion.getSubject());
if (assertion.getConditions() != null) {
for (Condition c : assertion.getConditions().getConditions()) {
if (!(c instanceof DelegationRestrictionType delegation)) {
continue;
}
for (Delegate d : delegation.getDelegates()) {
if (d.getEncryptedID() != null) {
try {
NameID decrypted = (NameID) this.decrypter.decrypt(d.getEncryptedID());
if (decrypted != null) {
d.setNameID(decrypted);
d.setEncryptedID(null);
}
}
catch (DecryptionException ex) {
throw new Saml2Exception(ex);
}
}
}
}
}
}
private void decryptAttributes(AttributeStatement statement) {
Collection<Attribute> decrypteds = new ArrayList<>();
Collection<EncryptedAttribute> encrypteds = new ArrayList<>();
for (EncryptedAttribute encrypted : statement.getEncryptedAttributes()) {
try {
Attribute decrypted = this.decrypter.decrypt(encrypted);
if (decrypted != null) {
encrypteds.add(encrypted);
decrypteds.add(decrypted);
}
}
catch (Exception ex) {
throw new Saml2Exception(ex);
}
}
statement.getEncryptedAttributes().removeAll(encrypteds);
statement.getAttributes().addAll(decrypteds);
}
private void decryptSubject(Subject subject) {
if (subject != null) {
if (subject.getEncryptedID() != null) {
try {
NameID decrypted = (NameID) this.decrypter.decrypt(subject.getEncryptedID());
if (decrypted != null) {
subject.setNameID(decrypted);
subject.setEncryptedID(null);
}
}
catch (final DecryptionException ex) {
throw new Saml2Exception(ex);
}
}
for (final SubjectConfirmation sc : subject.getSubjectConfirmations()) {
if (sc.getEncryptedID() != null) {
try {
NameID decrypted = (NameID) this.decrypter.decrypt(sc.getEncryptedID());
if (decrypted != null) {
sc.setNameID(decrypted);
sc.setEncryptedID(null);
}
}
catch (final DecryptionException ex) {
throw new Saml2Exception(ex);
}
}
}
}
}
private void decryptLogoutRequest(LogoutRequest request) {
if (request.getEncryptedID() != null) {
try {
NameID decrypted = (NameID) this.decrypter.decrypt(request.getEncryptedID());
if (decrypted != null) {
request.setNameID(decrypted);
request.setEncryptedID(null);
}
}
catch (DecryptionException ex) {
throw new Saml2Exception(ex);
}
}
}
}
}

View File

@ -1,113 +0,0 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.saml2.provider.service.authentication;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import org.opensaml.saml.saml2.core.Assertion;
import org.opensaml.saml.saml2.core.Attribute;
import org.opensaml.saml.saml2.core.AttributeStatement;
import org.opensaml.saml.saml2.core.EncryptedAssertion;
import org.opensaml.saml.saml2.core.EncryptedAttribute;
import org.opensaml.saml.saml2.core.NameID;
import org.opensaml.saml.saml2.core.Response;
import org.opensaml.saml.saml2.encryption.Decrypter;
import org.opensaml.saml.saml2.encryption.EncryptedElementTypeEncryptedKeyResolver;
import org.opensaml.security.credential.Credential;
import org.opensaml.security.credential.CredentialSupport;
import org.opensaml.xmlsec.encryption.support.ChainingEncryptedKeyResolver;
import org.opensaml.xmlsec.encryption.support.EncryptedKeyResolver;
import org.opensaml.xmlsec.encryption.support.InlineEncryptedKeyResolver;
import org.opensaml.xmlsec.encryption.support.SimpleRetrievalMethodEncryptedKeyResolver;
import org.opensaml.xmlsec.keyinfo.KeyInfoCredentialResolver;
import org.opensaml.xmlsec.keyinfo.impl.CollectionKeyInfoCredentialResolver;
import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
/**
* Utility methods for decrypting SAML components with OpenSAML
*
* For internal use only.
*
* @author Josh Cummings
*/
final class OpenSamlDecryptionUtils {
private static final EncryptedKeyResolver encryptedKeyResolver = new ChainingEncryptedKeyResolver(
Arrays.asList(new InlineEncryptedKeyResolver(), new EncryptedElementTypeEncryptedKeyResolver(),
new SimpleRetrievalMethodEncryptedKeyResolver()));
static void decryptResponseElements(Response response, RelyingPartyRegistration registration) {
Decrypter decrypter = decrypter(registration);
for (EncryptedAssertion encryptedAssertion : response.getEncryptedAssertions()) {
try {
Assertion assertion = decrypter.decrypt(encryptedAssertion);
response.getAssertions().add(assertion);
}
catch (Exception ex) {
throw new Saml2Exception(ex);
}
}
}
static void decryptAssertionElements(Assertion assertion, RelyingPartyRegistration registration) {
Decrypter decrypter = decrypter(registration);
for (AttributeStatement statement : assertion.getAttributeStatements()) {
for (EncryptedAttribute encryptedAttribute : statement.getEncryptedAttributes()) {
try {
Attribute attribute = decrypter.decrypt(encryptedAttribute);
statement.getAttributes().add(attribute);
}
catch (Exception ex) {
throw new Saml2Exception(ex);
}
}
}
if (assertion.getSubject() == null) {
return;
}
if (assertion.getSubject().getEncryptedID() == null) {
return;
}
try {
assertion.getSubject().setNameID((NameID) decrypter.decrypt(assertion.getSubject().getEncryptedID()));
}
catch (Exception ex) {
throw new Saml2Exception(ex);
}
}
private static Decrypter decrypter(RelyingPartyRegistration registration) {
Collection<Credential> credentials = new ArrayList<>();
for (Saml2X509Credential key : registration.getDecryptionX509Credentials()) {
Credential cred = CredentialSupport.getSimpleCredential(key.getCertificate(), key.getPrivateKey());
credentials.add(cred);
}
KeyInfoCredentialResolver resolver = new CollectionKeyInfoCredentialResolver(credentials);
Decrypter decrypter = new Decrypter(null, resolver, encryptedKeyResolver);
decrypter.setRootInNewDocument(true);
return decrypter;
}
private OpenSamlDecryptionUtils() {
}
}

View File

@ -0,0 +1,184 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.saml2.provider.service.authentication;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import javax.xml.namespace.QName;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.saml.saml2.core.Issuer;
import org.opensaml.saml.saml2.core.RequestAbstractType;
import org.opensaml.saml.saml2.core.StatusResponseType;
import org.opensaml.xmlsec.signature.SignableXMLObject;
import org.w3c.dom.Element;
import org.springframework.security.saml2.core.Saml2Error;
import org.springframework.security.saml2.core.Saml2ParameterNames;
import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.web.util.UriComponentsBuilder;
interface OpenSamlOperations {
<T extends XMLObject> T build(QName elementName);
<T extends XMLObject> T deserialize(String serialized);
<T extends XMLObject> T deserialize(InputStream serialized);
SerializationConfigurer<?> serialize(XMLObject object);
SerializationConfigurer<?> serialize(Element element);
SignatureConfigurer<?> withSigningKeys(Collection<Saml2X509Credential> credentials);
VerificationConfigurer withVerificationKeys(Collection<Saml2X509Credential> credentials);
DecryptionConfigurer withDecryptionKeys(Collection<Saml2X509Credential> credentials);
interface SerializationConfigurer<B extends SerializationConfigurer<B>> {
B prettyPrint(boolean pretty);
String serialize();
}
interface SignatureConfigurer<B extends SignatureConfigurer<B>> {
B algorithms(List<String> algs);
<O extends SignableXMLObject> O sign(O object);
Map<String, String> sign(Map<String, String> params);
}
interface VerificationConfigurer {
VerificationConfigurer entityId(String entityId);
Collection<Saml2Error> verify(SignableXMLObject signable);
Collection<Saml2Error> verify(VerificationConfigurer.RedirectParameters parameters);
final class RedirectParameters {
private final String id;
private final Issuer issuer;
private final String algorithm;
private final byte[] signature;
private final byte[] content;
RedirectParameters(Map<String, String> parameters, String parametersQuery, RequestAbstractType request) {
this.id = request.getID();
this.issuer = request.getIssuer();
this.algorithm = parameters.get(Saml2ParameterNames.SIG_ALG);
if (parameters.get(Saml2ParameterNames.SIGNATURE) != null) {
this.signature = Saml2Utils.samlDecode(parameters.get(Saml2ParameterNames.SIGNATURE));
}
else {
this.signature = null;
}
Map<String, String> queryParams = UriComponentsBuilder.newInstance()
.query(parametersQuery)
.build(true)
.getQueryParams()
.toSingleValueMap();
String relayState = parameters.get(Saml2ParameterNames.RELAY_STATE);
this.content = getContent(Saml2ParameterNames.SAML_REQUEST, relayState, queryParams);
}
RedirectParameters(Map<String, String> parameters, String parametersQuery, StatusResponseType response) {
this.id = response.getID();
this.issuer = response.getIssuer();
this.algorithm = parameters.get(Saml2ParameterNames.SIG_ALG);
if (parameters.get(Saml2ParameterNames.SIGNATURE) != null) {
this.signature = Saml2Utils.samlDecode(parameters.get(Saml2ParameterNames.SIGNATURE));
}
else {
this.signature = null;
}
Map<String, String> queryParams = UriComponentsBuilder.newInstance()
.query(parametersQuery)
.build(true)
.getQueryParams()
.toSingleValueMap();
String relayState = parameters.get(Saml2ParameterNames.RELAY_STATE);
this.content = getContent(Saml2ParameterNames.SAML_RESPONSE, relayState, queryParams);
}
static byte[] getContent(String samlObject, String relayState, final Map<String, String> queryParams) {
if (Objects.nonNull(relayState)) {
return String
.format("%s=%s&%s=%s&%s=%s", samlObject, queryParams.get(samlObject),
Saml2ParameterNames.RELAY_STATE, queryParams.get(Saml2ParameterNames.RELAY_STATE),
Saml2ParameterNames.SIG_ALG, queryParams.get(Saml2ParameterNames.SIG_ALG))
.getBytes(StandardCharsets.UTF_8);
}
else {
return String
.format("%s=%s&%s=%s", samlObject, queryParams.get(samlObject), Saml2ParameterNames.SIG_ALG,
queryParams.get(Saml2ParameterNames.SIG_ALG))
.getBytes(StandardCharsets.UTF_8);
}
}
String getId() {
return this.id;
}
Issuer getIssuer() {
return this.issuer;
}
byte[] getContent() {
return this.content;
}
String getAlgorithm() {
return this.algorithm;
}
byte[] getSignature() {
return this.signature;
}
boolean hasSignature() {
return this.signature != null;
}
}
}
interface DecryptionConfigurer {
void decrypt(XMLObject object);
}
}

View File

@ -1,194 +0,0 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.saml2.provider.service.authentication;
import java.nio.charset.StandardCharsets;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import net.shibboleth.utilities.java.support.resolver.CriteriaSet;
import net.shibboleth.utilities.java.support.xml.SerializeSupport;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
import org.opensaml.core.xml.io.Marshaller;
import org.opensaml.core.xml.io.MarshallingException;
import org.opensaml.saml.security.impl.SAMLMetadataSignatureSigningParametersResolver;
import org.opensaml.security.SecurityException;
import org.opensaml.security.credential.BasicCredential;
import org.opensaml.security.credential.Credential;
import org.opensaml.security.credential.CredentialSupport;
import org.opensaml.security.credential.UsageType;
import org.opensaml.xmlsec.SignatureSigningParameters;
import org.opensaml.xmlsec.SignatureSigningParametersResolver;
import org.opensaml.xmlsec.criterion.SignatureSigningConfigurationCriterion;
import org.opensaml.xmlsec.crypto.XMLSigningUtil;
import org.opensaml.xmlsec.impl.BasicSignatureSigningConfiguration;
import org.opensaml.xmlsec.keyinfo.KeyInfoGeneratorManager;
import org.opensaml.xmlsec.keyinfo.NamedKeyInfoGeneratorManager;
import org.opensaml.xmlsec.keyinfo.impl.X509KeyInfoGeneratorFactory;
import org.opensaml.xmlsec.signature.SignableXMLObject;
import org.opensaml.xmlsec.signature.support.SignatureConstants;
import org.opensaml.xmlsec.signature.support.SignatureSupport;
import org.w3c.dom.Element;
import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.core.Saml2ParameterNames;
import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.util.Assert;
import org.springframework.web.util.UriComponentsBuilder;
import org.springframework.web.util.UriUtils;
/**
* Utility methods for signing SAML components with OpenSAML
*
* For internal use only.
*
* @author Josh Cummings
*/
final class OpenSamlSigningUtils {
static String serialize(XMLObject object) {
try {
Marshaller marshaller = XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(object);
Element element = marshaller.marshall(object);
return SerializeSupport.nodeToString(element);
}
catch (MarshallingException ex) {
throw new Saml2Exception(ex);
}
}
static <O extends SignableXMLObject> O sign(O object, RelyingPartyRegistration relyingPartyRegistration) {
SignatureSigningParameters parameters = resolveSigningParameters(relyingPartyRegistration);
try {
SignatureSupport.signObject(object, parameters);
return object;
}
catch (Exception ex) {
throw new Saml2Exception(ex);
}
}
static QueryParametersPartial sign(RelyingPartyRegistration registration) {
return new QueryParametersPartial(registration);
}
private static SignatureSigningParameters resolveSigningParameters(
RelyingPartyRegistration relyingPartyRegistration) {
List<Credential> credentials = resolveSigningCredentials(relyingPartyRegistration);
List<String> algorithms = relyingPartyRegistration.getAssertingPartyMetadata().getSigningAlgorithms();
List<String> digests = Collections.singletonList(SignatureConstants.ALGO_ID_DIGEST_SHA256);
String canonicalization = SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS;
SignatureSigningParametersResolver resolver = new SAMLMetadataSignatureSigningParametersResolver();
CriteriaSet criteria = new CriteriaSet();
BasicSignatureSigningConfiguration signingConfiguration = new BasicSignatureSigningConfiguration();
signingConfiguration.setSigningCredentials(credentials);
signingConfiguration.setSignatureAlgorithms(algorithms);
signingConfiguration.setSignatureReferenceDigestMethods(digests);
signingConfiguration.setSignatureCanonicalizationAlgorithm(canonicalization);
signingConfiguration.setKeyInfoGeneratorManager(buildSignatureKeyInfoGeneratorManager());
criteria.add(new SignatureSigningConfigurationCriterion(signingConfiguration));
try {
SignatureSigningParameters parameters = resolver.resolveSingle(criteria);
Assert.notNull(parameters, "Failed to resolve any signing credential");
return parameters;
}
catch (Exception ex) {
throw new Saml2Exception(ex);
}
}
private static NamedKeyInfoGeneratorManager buildSignatureKeyInfoGeneratorManager() {
final NamedKeyInfoGeneratorManager namedManager = new NamedKeyInfoGeneratorManager();
namedManager.setUseDefaultManager(true);
final KeyInfoGeneratorManager defaultManager = namedManager.getDefaultManager();
// Generator for X509Credentials
final X509KeyInfoGeneratorFactory x509Factory = new X509KeyInfoGeneratorFactory();
x509Factory.setEmitEntityCertificate(true);
x509Factory.setEmitEntityCertificateChain(true);
defaultManager.registerFactory(x509Factory);
return namedManager;
}
private static List<Credential> resolveSigningCredentials(RelyingPartyRegistration relyingPartyRegistration) {
List<Credential> credentials = new ArrayList<>();
for (Saml2X509Credential x509Credential : relyingPartyRegistration.getSigningX509Credentials()) {
X509Certificate certificate = x509Credential.getCertificate();
PrivateKey privateKey = x509Credential.getPrivateKey();
BasicCredential credential = CredentialSupport.getSimpleCredential(certificate, privateKey);
credential.setEntityId(relyingPartyRegistration.getEntityId());
credential.setUsageType(UsageType.SIGNING);
credentials.add(credential);
}
return credentials;
}
private OpenSamlSigningUtils() {
}
static class QueryParametersPartial {
final RelyingPartyRegistration registration;
final Map<String, String> components = new LinkedHashMap<>();
QueryParametersPartial(RelyingPartyRegistration registration) {
this.registration = registration;
}
QueryParametersPartial param(String key, String value) {
this.components.put(key, value);
return this;
}
Map<String, String> parameters() {
SignatureSigningParameters parameters = resolveSigningParameters(this.registration);
Credential credential = parameters.getSigningCredential();
String algorithmUri = parameters.getSignatureAlgorithm();
this.components.put(Saml2ParameterNames.SIG_ALG, algorithmUri);
UriComponentsBuilder builder = UriComponentsBuilder.newInstance();
for (Map.Entry<String, String> component : this.components.entrySet()) {
builder.queryParam(component.getKey(),
UriUtils.encode(component.getValue(), StandardCharsets.ISO_8859_1));
}
String queryString = builder.build(true).toString().substring(1);
try {
byte[] rawSignature = XMLSigningUtil.signWithURI(credential, algorithmUri,
queryString.getBytes(StandardCharsets.UTF_8));
String b64Signature = Saml2Utils.samlEncode(rawSignature);
this.components.put(Saml2ParameterNames.SIGNATURE, b64Signature);
}
catch (SecurityException ex) {
throw new Saml2Exception(ex);
}
return this.components;
}
}
}

View File

@ -1,222 +0,0 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.saml2.provider.service.authentication;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Set;
import jakarta.servlet.http.HttpServletRequest;
import net.shibboleth.utilities.java.support.resolver.CriteriaSet;
import org.opensaml.core.criterion.EntityIdCriterion;
import org.opensaml.saml.common.xml.SAMLConstants;
import org.opensaml.saml.criterion.ProtocolCriterion;
import org.opensaml.saml.metadata.criteria.role.impl.EvaluableProtocolRoleDescriptorCriterion;
import org.opensaml.saml.saml2.core.Issuer;
import org.opensaml.saml.saml2.core.RequestAbstractType;
import org.opensaml.saml.saml2.core.StatusResponseType;
import org.opensaml.saml.security.impl.SAMLSignatureProfileValidator;
import org.opensaml.security.credential.Credential;
import org.opensaml.security.credential.CredentialResolver;
import org.opensaml.security.credential.UsageType;
import org.opensaml.security.credential.criteria.impl.EvaluableEntityIDCredentialCriterion;
import org.opensaml.security.credential.criteria.impl.EvaluableUsageCredentialCriterion;
import org.opensaml.security.credential.impl.CollectionCredentialResolver;
import org.opensaml.security.criteria.UsageCriterion;
import org.opensaml.security.x509.BasicX509Credential;
import org.opensaml.xmlsec.config.impl.DefaultSecurityConfigurationBootstrap;
import org.opensaml.xmlsec.signature.Signature;
import org.opensaml.xmlsec.signature.support.SignatureTrustEngine;
import org.opensaml.xmlsec.signature.support.impl.ExplicitKeySignatureTrustEngine;
import org.springframework.security.saml2.core.Saml2Error;
import org.springframework.security.saml2.core.Saml2ErrorCodes;
import org.springframework.security.saml2.core.Saml2ParameterNames;
import org.springframework.security.saml2.core.Saml2ResponseValidatorResult;
import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.web.util.UriUtils;
/**
* Utility methods for verifying SAML component signatures with OpenSAML
*
* For internal use only.
*
* @author Josh Cummings
*/
final class OpenSamlVerificationUtils {
static VerifierPartial verifySignature(StatusResponseType object, RelyingPartyRegistration registration) {
return new VerifierPartial(object, registration);
}
static VerifierPartial verifySignature(RequestAbstractType object, RelyingPartyRegistration registration) {
return new VerifierPartial(object, registration);
}
static SignatureTrustEngine trustEngine(RelyingPartyRegistration registration) {
Set<Credential> credentials = new HashSet<>();
Collection<Saml2X509Credential> keys = registration.getAssertingPartyMetadata()
.getVerificationX509Credentials();
for (Saml2X509Credential key : keys) {
BasicX509Credential cred = new BasicX509Credential(key.getCertificate());
cred.setUsageType(UsageType.SIGNING);
cred.setEntityId(registration.getAssertingPartyMetadata().getEntityId());
credentials.add(cred);
}
CredentialResolver credentialsResolver = new CollectionCredentialResolver(credentials);
return new ExplicitKeySignatureTrustEngine(credentialsResolver,
DefaultSecurityConfigurationBootstrap.buildBasicInlineKeyInfoCredentialResolver());
}
private OpenSamlVerificationUtils() {
}
static class VerifierPartial {
private final String id;
private final CriteriaSet criteria;
private final SignatureTrustEngine trustEngine;
VerifierPartial(StatusResponseType object, RelyingPartyRegistration registration) {
this.id = object.getID();
this.criteria = verificationCriteria(object.getIssuer());
this.trustEngine = trustEngine(registration);
}
VerifierPartial(RequestAbstractType object, RelyingPartyRegistration registration) {
this.id = object.getID();
this.criteria = verificationCriteria(object.getIssuer());
this.trustEngine = trustEngine(registration);
}
Saml2ResponseValidatorResult redirect(HttpServletRequest request, String objectParameterName) {
RedirectSignature signature = new RedirectSignature(request, objectParameterName);
if (signature.getAlgorithm() == null) {
return Saml2ResponseValidatorResult.failure(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE,
"Missing signature algorithm for object [" + this.id + "]"));
}
if (!signature.hasSignature()) {
return Saml2ResponseValidatorResult.failure(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE,
"Missing signature for object [" + this.id + "]"));
}
Collection<Saml2Error> errors = new ArrayList<>();
String algorithmUri = signature.getAlgorithm();
try {
if (!this.trustEngine.validate(signature.getSignature(), signature.getContent(), algorithmUri,
this.criteria, null)) {
errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE,
"Invalid signature for object [" + this.id + "]"));
}
}
catch (Exception ex) {
errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE,
"Invalid signature for object [" + this.id + "]: "));
}
return Saml2ResponseValidatorResult.failure(errors);
}
Saml2ResponseValidatorResult post(Signature signature) {
Collection<Saml2Error> errors = new ArrayList<>();
SAMLSignatureProfileValidator profileValidator = new SAMLSignatureProfileValidator();
try {
profileValidator.validate(signature);
}
catch (Exception ex) {
errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE,
"Invalid signature for object [" + this.id + "]: "));
}
try {
if (!this.trustEngine.validate(signature, this.criteria)) {
errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE,
"Invalid signature for object [" + this.id + "]"));
}
}
catch (Exception ex) {
errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE,
"Invalid signature for object [" + this.id + "]: "));
}
return Saml2ResponseValidatorResult.failure(errors);
}
private CriteriaSet verificationCriteria(Issuer issuer) {
CriteriaSet criteria = new CriteriaSet();
criteria.add(new EvaluableEntityIDCredentialCriterion(new EntityIdCriterion(issuer.getValue())));
criteria.add(new EvaluableProtocolRoleDescriptorCriterion(new ProtocolCriterion(SAMLConstants.SAML20P_NS)));
criteria.add(new EvaluableUsageCredentialCriterion(new UsageCriterion(UsageType.SIGNING)));
return criteria;
}
private static class RedirectSignature {
private final HttpServletRequest request;
private final String objectParameterName;
RedirectSignature(HttpServletRequest request, String objectParameterName) {
this.request = request;
this.objectParameterName = objectParameterName;
}
String getAlgorithm() {
return this.request.getParameter(Saml2ParameterNames.SIG_ALG);
}
byte[] getContent() {
if (this.request.getParameter(Saml2ParameterNames.RELAY_STATE) != null) {
return String
.format("%s=%s&%s=%s&%s=%s", this.objectParameterName, UriUtils
.encode(this.request.getParameter(this.objectParameterName), StandardCharsets.ISO_8859_1),
Saml2ParameterNames.RELAY_STATE,
UriUtils.encode(this.request.getParameter(Saml2ParameterNames.RELAY_STATE),
StandardCharsets.ISO_8859_1),
Saml2ParameterNames.SIG_ALG,
UriUtils.encode(getAlgorithm(), StandardCharsets.ISO_8859_1))
.getBytes(StandardCharsets.UTF_8);
}
else {
return String
.format("%s=%s&%s=%s", this.objectParameterName,
UriUtils.encode(this.request.getParameter(this.objectParameterName),
StandardCharsets.ISO_8859_1),
Saml2ParameterNames.SIG_ALG,
UriUtils.encode(getAlgorithm(), StandardCharsets.ISO_8859_1))
.getBytes(StandardCharsets.UTF_8);
}
}
byte[] getSignature() {
return Saml2Utils.samlDecode(this.request.getParameter(Saml2ParameterNames.SIGNATURE));
}
boolean hasSignature() {
return this.request.getParameter(Saml2ParameterNames.SIGNATURE) != null;
}
}
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,6 +19,7 @@ package org.springframework.security.saml2.provider.service.authentication;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Base64;
import java.util.zip.Deflater;
import java.util.zip.DeflaterOutputStream;
@ -28,7 +29,11 @@ import java.util.zip.InflaterOutputStream;
import org.springframework.security.saml2.Saml2Exception;
/**
* @since 5.3
* Utility methods for working with serialized SAML messages.
*
* For internal use only.
*
* @author Josh Cummings
*/
final class Saml2Utils {
@ -69,4 +74,123 @@ final class Saml2Utils {
}
}
static EncodingConfigurer withDecoded(String decoded) {
return new EncodingConfigurer(decoded);
}
static DecodingConfigurer withEncoded(String encoded) {
return new DecodingConfigurer(encoded);
}
static final class EncodingConfigurer {
private final String decoded;
private boolean deflate;
private EncodingConfigurer(String decoded) {
this.decoded = decoded;
}
EncodingConfigurer deflate(boolean deflate) {
this.deflate = deflate;
return this;
}
String encode() {
byte[] bytes = (this.deflate) ? Saml2Utils.samlDeflate(this.decoded)
: this.decoded.getBytes(StandardCharsets.UTF_8);
return Saml2Utils.samlEncode(bytes);
}
}
static final class DecodingConfigurer {
private static final Base64Checker BASE_64_CHECKER = new Base64Checker();
private final String encoded;
private boolean inflate;
private boolean requireBase64;
private DecodingConfigurer(String encoded) {
this.encoded = encoded;
}
DecodingConfigurer inflate(boolean inflate) {
this.inflate = inflate;
return this;
}
DecodingConfigurer requireBase64(boolean requireBase64) {
this.requireBase64 = requireBase64;
return this;
}
String decode() {
if (this.requireBase64) {
BASE_64_CHECKER.checkAcceptable(this.encoded);
}
byte[] bytes = Saml2Utils.samlDecode(this.encoded);
return (this.inflate) ? Saml2Utils.samlInflate(bytes) : new String(bytes, StandardCharsets.UTF_8);
}
static class Base64Checker {
private static final int[] values = genValueMapping();
Base64Checker() {
}
private static int[] genValueMapping() {
byte[] alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
.getBytes(StandardCharsets.ISO_8859_1);
int[] values = new int[256];
Arrays.fill(values, -1);
for (int i = 0; i < alphabet.length; i++) {
values[alphabet[i] & 0xff] = i;
}
return values;
}
boolean isAcceptable(String s) {
int goodChars = 0;
int lastGoodCharVal = -1;
// count number of characters from Base64 alphabet
for (int i = 0; i < s.length(); i++) {
int val = values[0xff & s.charAt(i)];
if (val != -1) {
lastGoodCharVal = val;
goodChars++;
}
}
// in cases of an incomplete final chunk, ensure the unused bits are zero
switch (goodChars % 4) {
case 0:
return true;
case 2:
return (lastGoodCharVal & 0b1111) == 0;
case 3:
return (lastGoodCharVal & 0b11) == 0;
default:
return false;
}
}
void checkAcceptable(String ins) {
if (!isAcceptable(ins)) {
throw new IllegalArgumentException("Failed to decode SAMLResponse");
}
}
}
}
}

View File

@ -32,12 +32,9 @@ import java.util.function.Consumer;
import javax.xml.namespace.QName;
import com.fasterxml.jackson.databind.ObjectMapper;
import net.shibboleth.utilities.java.support.xml.SerializeSupport;
import org.junit.jupiter.api.Test;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
import org.opensaml.core.xml.io.Marshaller;
import org.opensaml.core.xml.io.MarshallingException;
import org.opensaml.core.xml.schema.XSDateTime;
import org.opensaml.core.xml.schema.impl.XSDateTimeBuilder;
import org.opensaml.saml.common.SignableSAMLObject;
@ -68,12 +65,10 @@ import org.opensaml.saml.saml2.core.impl.StatusBuilder;
import org.opensaml.saml.saml2.core.impl.StatusCodeBuilder;
import org.opensaml.xmlsec.encryption.impl.EncryptedDataBuilder;
import org.opensaml.xmlsec.signature.support.SignatureConstants;
import org.w3c.dom.Element;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.core.Authentication;
import org.springframework.security.jackson2.SecurityJackson2Modules;
import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.core.Saml2Error;
import org.springframework.security.saml2.core.Saml2ErrorCodes;
import org.springframework.security.saml2.core.Saml2ResponseValidatorResult;
@ -107,6 +102,8 @@ public class OpenSaml4AuthenticationProviderTests {
private static String ASSERTING_PARTY_ENTITY_ID = "https://some.idp.test/saml2/idp";
private final OpenSamlOperations saml = new OpenSaml4Template();
private OpenSaml4AuthenticationProvider provider = new OpenSaml4AuthenticationProvider();
private Saml2AuthenticatedPrincipal principal = new DefaultSaml2AuthenticatedPrincipal("name",
@ -839,14 +836,7 @@ public class OpenSaml4AuthenticationProviderTests {
}
private String serialize(XMLObject object) {
try {
Marshaller marshaller = XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(object);
Element element = marshaller.marshall(object);
return SerializeSupport.nodeToString(element);
}
catch (MarshallingException ex) {
throw new Saml2Exception(ex);
}
return this.saml.serialize(object).serialize();
}
private Consumer<Saml2AuthenticationException> errorOf(String errorCode) {

View File

@ -1,58 +0,0 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.saml2.provider.service.authentication;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.opensaml.saml.saml2.core.Response;
import org.opensaml.xmlsec.signature.Signature;
import org.springframework.security.saml2.core.TestSaml2X509Credentials;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import static org.assertj.core.api.Assertions.assertThat;
/**
* Test open SAML signatures
*/
public class OpenSamlSigningUtilsTests {
private RelyingPartyRegistration registration;
@BeforeEach
public void setup() {
this.registration = RelyingPartyRegistration.withRegistrationId("saml-idp")
.entityId("https://some.idp.example.com/entity-id")
.signingX509Credentials((c) -> {
c.add(TestSaml2X509Credentials.relyingPartySigningCredential());
c.add(TestSaml2X509Credentials.assertingPartySigningCredential());
})
.assertingPartyDetails((c) -> c.entityId("https://some.idp.example.com/entity-id")
.singleSignOnServiceLocation("https://some.idp.example.com/service-location"))
.build();
}
@Test
public void whenSigningAnObjectThenKeyInfoIsPartOfTheSignature() throws Exception {
Response response = TestOpenSamlObjects.response();
OpenSamlSigningUtils.sign(response, this.registration);
Signature signature = response.getSignature();
assertThat(signature).isNotNull();
assertThat(signature.getKeyInfo()).isNotNull();
}
}