Polish OpenSamlAuthenticationProvider

- Use type-safe CriteriaSet
- Keep Assertion immutable

Closes gh-8471
This commit is contained in:
Josh Cummings 2020-03-31 15:55:32 -06:00 committed by Eleftheria Stein
parent 7748fb00ba
commit d4dbe069ad
1 changed files with 262 additions and 241 deletions

View File

@ -15,12 +15,28 @@
*/ */
package org.springframework.security.saml2.provider.service.authentication; package org.springframework.security.saml2.provider.service.authentication;
import java.security.cert.X509Certificate;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.annotation.Nonnull;
import net.shibboleth.utilities.java.support.resolver.CriteriaSet;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.opensaml.saml.common.SignableSAMLObject; import org.opensaml.core.criterion.EntityIdCriterion;
import org.opensaml.saml.common.assertion.AssertionValidationException;
import org.opensaml.saml.common.assertion.ValidationContext; import org.opensaml.saml.common.assertion.ValidationContext;
import org.opensaml.saml.common.assertion.ValidationResult; import org.opensaml.saml.common.assertion.ValidationResult;
import org.opensaml.saml.common.xml.SAMLConstants;
import org.opensaml.saml.criterion.ProtocolCriterion;
import org.opensaml.saml.metadata.criteria.role.impl.EvaluableProtocolRoleDescriptorCriterion;
import org.opensaml.saml.saml2.assertion.ConditionValidator; import org.opensaml.saml.saml2.assertion.ConditionValidator;
import org.opensaml.saml.saml2.assertion.SAML20AssertionValidator; import org.opensaml.saml.saml2.assertion.SAML20AssertionValidator;
import org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters; import org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters;
@ -40,16 +56,20 @@ import org.opensaml.saml.security.impl.SAMLSignatureProfileValidator;
import org.opensaml.security.credential.Credential; import org.opensaml.security.credential.Credential;
import org.opensaml.security.credential.CredentialResolver; import org.opensaml.security.credential.CredentialResolver;
import org.opensaml.security.credential.CredentialSupport; import org.opensaml.security.credential.CredentialSupport;
import org.opensaml.security.credential.UsageType;
import org.opensaml.security.credential.criteria.impl.EvaluableEntityIDCredentialCriterion;
import org.opensaml.security.credential.criteria.impl.EvaluableUsageCredentialCriterion;
import org.opensaml.security.credential.impl.CollectionCredentialResolver; import org.opensaml.security.credential.impl.CollectionCredentialResolver;
import org.opensaml.security.criteria.UsageCriterion;
import org.opensaml.security.x509.BasicX509Credential;
import org.opensaml.xmlsec.config.impl.DefaultSecurityConfigurationBootstrap; import org.opensaml.xmlsec.config.impl.DefaultSecurityConfigurationBootstrap;
import org.opensaml.xmlsec.encryption.support.DecryptionException; import org.opensaml.xmlsec.encryption.support.DecryptionException;
import org.opensaml.xmlsec.keyinfo.KeyInfoCredentialResolver; import org.opensaml.xmlsec.keyinfo.KeyInfoCredentialResolver;
import org.opensaml.xmlsec.keyinfo.impl.StaticKeyInfoCredentialResolver; import org.opensaml.xmlsec.keyinfo.impl.StaticKeyInfoCredentialResolver;
import org.opensaml.xmlsec.signature.support.SignatureException;
import org.opensaml.xmlsec.signature.support.SignaturePrevalidator; import org.opensaml.xmlsec.signature.support.SignaturePrevalidator;
import org.opensaml.xmlsec.signature.support.SignatureTrustEngine; import org.opensaml.xmlsec.signature.support.SignatureTrustEngine;
import org.opensaml.xmlsec.signature.support.SignatureValidator;
import org.opensaml.xmlsec.signature.support.impl.ExplicitKeySignatureTrustEngine; import org.opensaml.xmlsec.signature.support.impl.ExplicitKeySignatureTrustEngine;
import org.springframework.core.convert.converter.Converter; import org.springframework.core.convert.converter.Converter;
import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
@ -60,30 +80,24 @@ import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMap
import org.springframework.security.saml2.Saml2Exception; import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.credentials.Saml2X509Credential; import org.springframework.security.saml2.credentials.Saml2X509Credential;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import java.security.cert.X509Certificate;
import java.time.Duration;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static java.lang.String.format;
import static java.util.Collections.singleton; import static java.util.Collections.singleton;
import static java.util.Collections.singletonList; import static java.util.Collections.singletonList;
import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.CLOCK_SKEW;
import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.COND_VALID_AUDIENCES;
import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SIGNATURE_REQUIRED;
import static org.springframework.security.saml2.provider.service.authentication.Saml2ErrorCodes.DECRYPTION_ERROR; import static org.springframework.security.saml2.provider.service.authentication.Saml2ErrorCodes.DECRYPTION_ERROR;
import static org.springframework.security.saml2.provider.service.authentication.Saml2ErrorCodes.INTERNAL_VALIDATION_ERROR;
import static org.springframework.security.saml2.provider.service.authentication.Saml2ErrorCodes.INVALID_ASSERTION;
import static org.springframework.security.saml2.provider.service.authentication.Saml2ErrorCodes.INVALID_DESTINATION; import static org.springframework.security.saml2.provider.service.authentication.Saml2ErrorCodes.INVALID_DESTINATION;
import static org.springframework.security.saml2.provider.service.authentication.Saml2ErrorCodes.INVALID_ISSUER; import static org.springframework.security.saml2.provider.service.authentication.Saml2ErrorCodes.INVALID_ISSUER;
import static org.springframework.security.saml2.provider.service.authentication.Saml2ErrorCodes.INVALID_SIGNATURE;
import static org.springframework.security.saml2.provider.service.authentication.Saml2ErrorCodes.MALFORMED_RESPONSE_DATA; import static org.springframework.security.saml2.provider.service.authentication.Saml2ErrorCodes.MALFORMED_RESPONSE_DATA;
import static org.springframework.security.saml2.provider.service.authentication.Saml2ErrorCodes.SUBJECT_NOT_FOUND; import static org.springframework.security.saml2.provider.service.authentication.Saml2ErrorCodes.SUBJECT_NOT_FOUND;
import static org.springframework.security.saml2.provider.service.authentication.Saml2ErrorCodes.UNKNOWN_RESPONSE_CLASS; import static org.springframework.security.saml2.provider.service.authentication.Saml2ErrorCodes.UNKNOWN_RESPONSE_CLASS;
import static org.springframework.security.saml2.provider.service.authentication.Saml2ErrorCodes.USERNAME_NOT_FOUND; import static org.springframework.security.saml2.provider.service.authentication.Saml2ErrorCodes.USERNAME_NOT_FOUND;
import static org.springframework.util.Assert.notNull; import static org.springframework.util.Assert.notNull;
import static org.springframework.util.StringUtils.hasText;
/** /**
* Implementation of {@link AuthenticationProvider} for SAML authentications when receiving a * Implementation of {@link AuthenticationProvider} for SAML authentications when receiving a
@ -125,6 +139,20 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
private static Log logger = LogFactory.getLog(OpenSamlAuthenticationProvider.class); private static Log logger = LogFactory.getLog(OpenSamlAuthenticationProvider.class);
private final List<ConditionValidator> conditions = Collections.singletonList(new AudienceRestrictionConditionValidator());
private final SubjectConfirmationValidator subjectConfirmationValidator = new BearerSubjectConfirmationValidator() {
@Nonnull
@Override
protected ValidationResult validateAddress(@Nonnull SubjectConfirmation confirmation,
@Nonnull Assertion assertion, @Nonnull ValidationContext context) {
// skipping address validation - gh-7514
return ValidationResult.VALID;
}
};
private final List<SubjectConfirmationValidator> subjects = Collections.singletonList(this.subjectConfirmationValidator);
private final List<StatementValidator> statements = Collections.emptyList();
private final SignaturePrevalidator signaturePrevalidator = new SAMLSignatureProfileValidator();
private final OpenSamlImplementation saml = OpenSamlImplementation.getInstance(); private final OpenSamlImplementation saml = OpenSamlImplementation.getInstance();
private Converter<Assertion, Collection<? extends GrantedAuthority>> authoritiesExtractor = private Converter<Assertion, Collection<? extends GrantedAuthority>> authoritiesExtractor =
(a -> singletonList(new SimpleGrantedAuthority("ROLE_USER"))); (a -> singletonList(new SimpleGrantedAuthority("ROLE_USER")));
@ -173,17 +201,17 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
public Authentication authenticate(Authentication authentication) throws AuthenticationException { public Authentication authenticate(Authentication authentication) throws AuthenticationException {
try { try {
Saml2AuthenticationToken token = (Saml2AuthenticationToken) authentication; Saml2AuthenticationToken token = (Saml2AuthenticationToken) authentication;
Response samlResponse = getSaml2Response(token); Response response = parse(token.getSaml2Response());
Assertion assertion = validateSaml2Response(token, token.getRecipientUri(), samlResponse); List<Assertion> validAssertions = validateResponse(token, response);
Assertion assertion = validAssertions.get(0);
String username = getUsername(token, assertion); String username = getUsername(token, assertion);
return new Saml2Authentication( return new Saml2Authentication(
new SimpleSaml2AuthenticatedPrincipal(username), token.getSaml2Response(), new SimpleSaml2AuthenticatedPrincipal(username), token.getSaml2Response(),
this.authoritiesMapper.mapAuthorities(getAssertionAuthorities(assertion)) this.authoritiesMapper.mapAuthorities(getAssertionAuthorities(assertion)));
);
} catch (Saml2AuthenticationException e) { } catch (Saml2AuthenticationException e) {
throw e; throw e;
} catch (Exception e) { } catch (Exception e) {
throw authException(Saml2ErrorCodes.INTERNAL_VALIDATION_ERROR, e.getMessage(), e); throw authException(INTERNAL_VALIDATION_ERROR, e.getMessage(), e);
} }
} }
@ -199,167 +227,9 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
return this.authoritiesExtractor.convert(assertion); return this.authoritiesExtractor.convert(assertion);
} }
private String getUsername(Saml2AuthenticationToken token, Assertion assertion) throws Saml2AuthenticationException { private Response parse(String response) throws Saml2Exception, Saml2AuthenticationException {
String username = null;
Subject subject = assertion.getSubject();
if (subject == null) {
throw authException(SUBJECT_NOT_FOUND, "Assertion [" + assertion.getID() + "] is missing a subject");
}
if (subject.getNameID() != null) {
username = subject.getNameID().getValue();
}
else if (subject.getEncryptedID() != null) {
NameID nameId = decrypt(token, subject.getEncryptedID());
username = nameId.getValue();
}
if (username == null) {
throw authException(USERNAME_NOT_FOUND, "Assertion [" + assertion.getID() + "] is missing a user identifier");
}
return username;
}
private Assertion validateSaml2Response(Saml2AuthenticationToken token,
String recipient,
Response samlResponse) throws Saml2AuthenticationException {
//optional validation if the response contains a destination
if (hasText(samlResponse.getDestination()) && !recipient.equals(samlResponse.getDestination())) {
throw authException(INVALID_DESTINATION, "Invalid SAML response destination: " + samlResponse.getDestination());
}
String issuer = samlResponse.getIssuer().getValue();
if (logger.isDebugEnabled()) {
logger.debug("Validating SAML response from " + issuer);
}
if (!hasText(issuer) || (!issuer.equals(token.getIdpEntityId()))) {
String message = String.format("Response issuer '%s' doesn't match '%s'", issuer, token.getIdpEntityId());
throw authException(INVALID_ISSUER, message);
}
Saml2AuthenticationException lastValidationError = null;
boolean responseSigned = hasValidSignature(samlResponse, token);
for (Assertion a : samlResponse.getAssertions()) {
if (logger.isDebugEnabled()) {
logger.debug("Checking plain assertion validity " + a);
}
try { try {
validateAssertion(recipient, a, token, !responseSigned); Object result = this.saml.resolve(response);
return a;
} catch (Saml2AuthenticationException e) {
lastValidationError = e;
}
}
for (EncryptedAssertion ea : samlResponse.getEncryptedAssertions()) {
if (logger.isDebugEnabled()) {
logger.debug("Checking encrypted assertion validity " + ea);
}
try {
Assertion a = decrypt(token, ea);
validateAssertion(recipient, a, token, !responseSigned);
return a;
} catch (Saml2AuthenticationException e) {
lastValidationError = e;
}
}
if (lastValidationError != null) {
throw lastValidationError;
}
else {
throw authException(MALFORMED_RESPONSE_DATA, "No assertions found in response.");
}
}
private boolean hasValidSignature(SignableSAMLObject samlObject, Saml2AuthenticationToken token) {
if (!samlObject.isSigned()) {
if (logger.isDebugEnabled()) {
logger.debug("SAML object is not signed, no signatures found");
}
return false;
}
List<X509Certificate> verificationKeys = getVerificationCertificates(token);
if (verificationKeys.isEmpty()) {
return false;
}
for (X509Certificate certificate : verificationKeys) {
Credential credential = getVerificationCredential(certificate);
try {
SignatureValidator.validate(samlObject.getSignature(), credential);
if (logger.isDebugEnabled()) {
logger.debug("Valid signature found in SAML object:"+samlObject.getClass().getName());
}
return true;
}
catch (SignatureException ignored) {
if (logger.isTraceEnabled()) {
logger.trace("Signature validation failed with cert:"+certificate.toString(), ignored);
}
else if (logger.isDebugEnabled()) {
logger.debug("Signature validation failed with cert:"+certificate.toString());
}
}
}
return false;
}
private void validateAssertion(String recipient, Assertion a, Saml2AuthenticationToken token, boolean signatureRequired) {
SAML20AssertionValidator validator = getAssertionValidator(token);
Map<String, Object> validationParams = new HashMap<>();
validationParams.put(SAML2AssertionValidationParameters.SIGNATURE_REQUIRED, false);
validationParams.put(
SAML2AssertionValidationParameters.CLOCK_SKEW,
this.responseTimeValidationSkew.toMillis()
);
validationParams.put(
SAML2AssertionValidationParameters.COND_VALID_AUDIENCES,
singleton(token.getLocalSpEntityId())
);
if (hasText(recipient)) {
validationParams.put(SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS, singleton(recipient));
}
if (signatureRequired && !hasValidSignature(a, token)) {
if (logger.isDebugEnabled()) {
logger.debug(format("Assertion [%s] does not a valid signature.", a.getID()));
}
throw authException(Saml2ErrorCodes.INVALID_SIGNATURE, "Assertion doesn't have a valid signature.");
}
//ensure that OpenSAML doesn't attempt signature validation, already performed
a.setSignature(null);
//ensure that we don't validate IP addresses as part of our validation gh-7514
if (a.getSubject() != null) {
for (SubjectConfirmation sc : a.getSubject().getSubjectConfirmations()) {
if (sc.getSubjectConfirmationData() != null) {
sc.getSubjectConfirmationData().setAddress(null);
}
}
}
//remainder of assertion validation
ValidationContext vctx = new ValidationContext(validationParams);
try {
ValidationResult result = validator.validate(a, vctx);
boolean valid = result.equals(ValidationResult.VALID);
if (!valid) {
if (logger.isDebugEnabled()) {
logger.debug(format("Failed to validate assertion from %s", token.getIdpEntityId()));
}
throw authException(Saml2ErrorCodes.INVALID_ASSERTION, vctx.getValidationFailureMessage());
}
}
catch (AssertionValidationException e) {
if (logger.isDebugEnabled()) {
logger.debug("Failed to validate assertion:", e);
}
throw authException(Saml2ErrorCodes.INTERNAL_VALIDATION_ERROR, e.getMessage(), e);
}
}
private Response getSaml2Response(Saml2AuthenticationToken token) throws Saml2Exception, Saml2AuthenticationException {
try {
Object result = this.saml.resolve(token.getSaml2Response());
if (result instanceof Response) { if (result instanceof Response) {
return (Response) result; return (Response) result;
} }
@ -372,68 +242,172 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
} }
private Saml2Error validationError(String code, String description) { private List<Assertion> validateResponse(Saml2AuthenticationToken token, Response response)
return new Saml2Error( throws Saml2AuthenticationException {
code,
description List<Assertion> validAssertions = new ArrayList<>();
); String issuer = response.getIssuer().getValue();
if (logger.isDebugEnabled()) {
logger.debug("Validating SAML response from " + issuer);
} }
private Saml2AuthenticationException authException(String code, String description) throws Saml2AuthenticationException { List<Assertion> assertions = new ArrayList<>(response.getAssertions());
return new Saml2AuthenticationException( for (EncryptedAssertion encryptedAssertion : response.getEncryptedAssertions()) {
validationError(code, description) Assertion assertion = decrypt(token, encryptedAssertion);
); assertions.add(assertion);
}
if (assertions.isEmpty()) {
throw authException(MALFORMED_RESPONSE_DATA, "No assertions found in response.");
} }
if (!isSigned(response, assertions)) {
private Saml2AuthenticationException authException(String code, String description, Exception cause) throws Saml2AuthenticationException { throw authException(INVALID_SIGNATURE, "Either the response or one of the assertions is unsigned. " +
return new Saml2AuthenticationException( "Please either sign the response or all of the assertions.");
validationError(code, description),
cause
);
} }
private SAML20AssertionValidator getAssertionValidator(Saml2AuthenticationToken provider) { SignatureTrustEngine signatureTrustEngine = buildSignatureTrustEngine(token);
List<ConditionValidator> conditions = Collections.singletonList(new AudienceRestrictionConditionValidator());
BearerSubjectConfirmationValidator subjectConfirmationValidator = new BearerSubjectConfirmationValidator();
List<SubjectConfirmationValidator> subjects = Collections.singletonList(subjectConfirmationValidator); Map<String, Saml2AuthenticationException> validationExceptions = new HashMap<>();
List<StatementValidator> statements = Collections.emptyList(); if (response.isSigned()) {
SAMLSignatureProfileValidator profileValidator = new SAMLSignatureProfileValidator();
try {
profileValidator.validate(response.getSignature());
} catch (Exception e) {
validationExceptions.put(INVALID_SIGNATURE, authException(INVALID_SIGNATURE,
"Invalid signature for SAML Response [" + response.getID() + "]", e));
}
try {
CriteriaSet criteriaSet = new CriteriaSet();
criteriaSet.add(new EvaluableEntityIDCredentialCriterion(new EntityIdCriterion(issuer)));
criteriaSet.add(new EvaluableProtocolRoleDescriptorCriterion(new ProtocolCriterion(SAMLConstants.SAML20P_NS)));
criteriaSet.add(new EvaluableUsageCredentialCriterion(new UsageCriterion(UsageType.SIGNING)));
if (!signatureTrustEngine.validate(response.getSignature(), criteriaSet)) {
validationExceptions.put(INVALID_SIGNATURE, authException(INVALID_SIGNATURE,
"Invalid signature for SAML Response [" + response.getID() + "]"));
}
} catch (Exception e) {
validationExceptions.put(INVALID_SIGNATURE, authException(INVALID_SIGNATURE,
"Invalid signature for SAML Response [" + response.getID() + "]", e));
}
}
String destination = response.getDestination();
if (StringUtils.hasText(destination) && !destination.equals(token.getRecipientUri())) {
String message = "Invalid destination [" + destination + "] for SAML response [" + response.getID() + "]";
validationExceptions.put(INVALID_DESTINATION, authException(INVALID_DESTINATION, message));
}
if (!StringUtils.hasText(issuer) || !issuer.equals(token.getIdpEntityId())) {
String message = String.format("Invalid issuer [%s] for SAML response [%s]", issuer, response.getID());
validationExceptions.put(INVALID_ISSUER, authException(INVALID_ISSUER, message));
}
SAML20AssertionValidator validator = buildSamlAssertionValidator(signatureTrustEngine);
ValidationContext context = buildValidationContext(token, response);
if (logger.isDebugEnabled()) {
logger.debug("Validating " + assertions.size() + " assertions");
}
for (Assertion assertion : assertions) {
if (logger.isTraceEnabled()) {
logger.trace("Validating assertion " + assertion.getID());
}
try {
validAssertions.add(validateAssertion(assertion, validator, context));
} catch (Exception e) {
String message = String.format("Invalid assertion [%s] for SAML response [%s]", assertion.getID(), response.getID());
validationExceptions.put(INVALID_ASSERTION, authException(INVALID_ASSERTION, message, e));
}
}
if (validationExceptions.isEmpty()) {
if (logger.isDebugEnabled()) {
logger.debug("Successfully validated SAML Response [" + response.getID() + "]");
}
} else {
if (logger.isTraceEnabled()) {
logger.debug("Found " + validationExceptions.size() + " validation errors in SAML response [" + response.getID() + "]: " +
validationExceptions.values());
} else if (logger.isDebugEnabled()) {
logger.debug("Found " + validationExceptions.size() + " validation errors in SAML response [" + response.getID() + "]");
}
}
if (!validationExceptions.isEmpty()) {
throw validationExceptions.values().iterator().next();
}
if (validAssertions.isEmpty()) {
throw authException(MALFORMED_RESPONSE_DATA, "No valid assertions found in response.");
}
return validAssertions;
}
private boolean isSigned(Response samlResponse, List<Assertion> assertions) {
if (samlResponse.isSigned()) {
return true;
}
for (Assertion assertion : assertions) {
if (!assertion.isSigned()) {
return false;
}
}
return true;
}
private SignatureTrustEngine buildSignatureTrustEngine(Saml2AuthenticationToken token) {
Set<Credential> credentials = new HashSet<>(); Set<Credential> credentials = new HashSet<>();
for (X509Certificate key : getVerificationCertificates(provider)) { for (X509Certificate key : getVerificationCertificates(token)) {
Credential cred = getVerificationCredential(key); BasicX509Credential cred = new BasicX509Credential(key);
cred.setUsageType(UsageType.SIGNING);
cred.setEntityId(token.getIdpEntityId());
credentials.add(cred); credentials.add(cred);
} }
CredentialResolver credentialsResolver = new CollectionCredentialResolver(credentials); CredentialResolver credentialsResolver = new CollectionCredentialResolver(credentials);
SignatureTrustEngine signatureTrustEngine = new ExplicitKeySignatureTrustEngine( return new ExplicitKeySignatureTrustEngine(
credentialsResolver, credentialsResolver,
DefaultSecurityConfigurationBootstrap.buildBasicInlineKeyInfoCredentialResolver() DefaultSecurityConfigurationBootstrap.buildBasicInlineKeyInfoCredentialResolver()
); );
SignaturePrevalidator signaturePrevalidator = new SAMLSignatureProfileValidator(); }
private ValidationContext buildValidationContext(Saml2AuthenticationToken token, Response response) {
Map<String, Object> validationParams = new HashMap<>();
validationParams.put(SIGNATURE_REQUIRED, !response.isSigned());
validationParams.put(CLOCK_SKEW, this.responseTimeValidationSkew.toMillis());
validationParams.put(COND_VALID_AUDIENCES, singleton(token.getLocalSpEntityId()));
if (StringUtils.hasText(token.getRecipientUri())) {
validationParams.put(SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS, singleton(token.getRecipientUri()));
}
return new ValidationContext(validationParams);
}
private SAML20AssertionValidator buildSamlAssertionValidator(SignatureTrustEngine signatureTrustEngine) {
return new SAML20AssertionValidator( return new SAML20AssertionValidator(
conditions, this.conditions, this.subjects, this.statements, signatureTrustEngine, this.signaturePrevalidator);
subjects,
statements,
signatureTrustEngine,
signaturePrevalidator
);
} }
private Credential getVerificationCredential(X509Certificate certificate) { private Assertion validateAssertion(Assertion assertion,
return CredentialSupport.getSimpleCredential(certificate, null); SAML20AssertionValidator validator, ValidationContext context) {
}
private Decrypter getDecrypter(Saml2X509Credential key) { ValidationResult result;
Credential credential = CredentialSupport.getSimpleCredential(key.getCertificate(), key.getPrivateKey()); try {
KeyInfoCredentialResolver resolver = new StaticKeyInfoCredentialResolver(credential); result = validator.validate(assertion, context);
Decrypter decrypter = new Decrypter(null, resolver, this.saml.getEncryptedKeyResolver()); } catch (Exception e) {
decrypter.setRootInNewDocument(true); throw new Saml2Exception("An error occurred while validation the assertion", e);
return decrypter; }
if (result != ValidationResult.VALID) {
throw new Saml2Exception("An error occurred while validating the assertion: " +
context.getValidationFailureMessage());
}
return assertion;
} }
private Assertion decrypt(Saml2AuthenticationToken token, EncryptedAssertion assertion) private Assertion decrypt(Saml2AuthenticationToken token, EncryptedAssertion assertion)
throws Saml2AuthenticationException { throws Saml2AuthenticationException {
Saml2AuthenticationException last = null; Saml2AuthenticationException last = null;
List<Saml2X509Credential> decryptionCredentials = getDecryptionCredentials(token); List<Saml2X509Credential> decryptionCredentials = getDecryptionCredentials(token);
if (decryptionCredentials.isEmpty()) { if (decryptionCredentials.isEmpty()) {
@ -451,22 +425,12 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
throw last; throw last;
} }
private NameID decrypt(Saml2AuthenticationToken token, EncryptedID assertion) throws Saml2AuthenticationException { private Decrypter getDecrypter(Saml2X509Credential key) {
Saml2AuthenticationException last = null; Credential credential = CredentialSupport.getSimpleCredential(key.getCertificate(), key.getPrivateKey());
List<Saml2X509Credential> decryptionCredentials = getDecryptionCredentials(token); KeyInfoCredentialResolver resolver = new StaticKeyInfoCredentialResolver(credential);
if (decryptionCredentials.isEmpty()) { Decrypter decrypter = new Decrypter(null, resolver, this.saml.getEncryptedKeyResolver());
throw authException(DECRYPTION_ERROR, "No valid decryption credentials found."); decrypter.setRootInNewDocument(true);
} return decrypter;
for (Saml2X509Credential key : decryptionCredentials) {
Decrypter decrypter = getDecrypter(key);
try {
return (NameID) decrypter.decrypt(assertion);
}
catch (DecryptionException e) {
last = authException(DECRYPTION_ERROR, e.getMessage(), e);
}
}
throw last;
} }
private List<Saml2X509Credential> getDecryptionCredentials(Saml2AuthenticationToken token) { private List<Saml2X509Credential> getDecryptionCredentials(Saml2AuthenticationToken token) {
@ -488,4 +452,61 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
} }
return result; return result;
} }
private String getUsername(Saml2AuthenticationToken token, Assertion assertion)
throws Saml2AuthenticationException {
String username = null;
Subject subject = assertion.getSubject();
if (subject == null) {
throw authException(SUBJECT_NOT_FOUND, "Assertion [" + assertion.getID() + "] is missing a subject");
}
if (subject.getNameID() != null) {
username = subject.getNameID().getValue();
}
else if (subject.getEncryptedID() != null) {
NameID nameId = decrypt(token, subject.getEncryptedID());
username = nameId.getValue();
}
if (username == null) {
throw authException(USERNAME_NOT_FOUND, "Assertion [" + assertion.getID() + "] is missing a user identifier");
}
return username;
}
private NameID decrypt(Saml2AuthenticationToken token, EncryptedID assertion)
throws Saml2AuthenticationException {
Saml2AuthenticationException last = null;
List<Saml2X509Credential> decryptionCredentials = getDecryptionCredentials(token);
if (decryptionCredentials.isEmpty()) {
throw authException(DECRYPTION_ERROR, "No valid decryption credentials found.");
}
for (Saml2X509Credential key : decryptionCredentials) {
Decrypter decrypter = getDecrypter(key);
try {
return (NameID) decrypter.decrypt(assertion);
}
catch (DecryptionException e) {
last = authException(DECRYPTION_ERROR, e.getMessage(), e);
}
}
throw last;
}
private Saml2Error validationError(String code, String description) {
return new Saml2Error(code, description);
}
private Saml2AuthenticationException authException(String code, String description)
throws Saml2AuthenticationException {
return new Saml2AuthenticationException(validationError(code, description));
}
private Saml2AuthenticationException authException(String code, String description, Exception cause)
throws Saml2AuthenticationException {
return new Saml2AuthenticationException(validationError(code, description), cause);
}
} }