Generalize SAML 2.0 Assertion Validation Support

Closes gh-8970
This commit is contained in:
Josh Cummings 2020-08-18 11:48:08 -06:00
parent 1069e91645
commit 7b3dda161b
No known key found for this signature in database
GPG Key ID: 49EF60DD7FF83443
5 changed files with 457 additions and 134 deletions

View File

@ -0,0 +1,121 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.saml2.core;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import org.springframework.util.Assert;
/**
* A result emitted from a SAML 2.0 Response validation attempt
*
* @author Josh Cummings
* @since 5.4
*/
public final class Saml2ResponseValidatorResult {
static final Saml2ResponseValidatorResult NO_ERRORS = new Saml2ResponseValidatorResult(Collections.emptyList());
private final Collection<Saml2Error> errors;
private Saml2ResponseValidatorResult(Collection<Saml2Error> errors) {
Assert.notNull(errors, "errors cannot be null");
this.errors = new ArrayList<>(errors);
}
/**
* Say whether this result indicates success
*
* @return whether this result has errors
*/
public boolean hasErrors() {
return !this.errors.isEmpty();
}
/**
* Return error details regarding the validation attempt
*
* @return the collection of results in this result, if any; returns an empty list otherwise
*/
public Collection<Saml2Error> getErrors() {
return Collections.unmodifiableCollection(this.errors);
}
/**
* Return a new {@link Saml2ResponseValidatorResult} that contains
* both the given {@link Saml2Error} and the errors from the result
*
* @param error the {@link Saml2Error} to append
* @return a new {@link Saml2ResponseValidatorResult} for further reporting
*/
public Saml2ResponseValidatorResult concat(Saml2Error error) {
Assert.notNull(error, "error cannot be null");
Collection<Saml2Error> errors = new ArrayList<>(this.errors);
errors.add(error);
return failure(errors);
}
/**
* Return a new {@link Saml2ResponseValidatorResult} that contains
* the errors from the given {@link Saml2ResponseValidatorResult} as well
* as this result.
*
* @param result the {@link Saml2ResponseValidatorResult} to merge with this one
* @return a new {@link Saml2ResponseValidatorResult} for further reporting
*/
public Saml2ResponseValidatorResult concat(Saml2ResponseValidatorResult result) {
Assert.notNull(result, "result cannot be null");
Collection<Saml2Error> errors = new ArrayList<>(this.errors);
errors.addAll(result.getErrors());
return failure(errors);
}
/**
* Construct a successful {@link Saml2ResponseValidatorResult}
*
* @return an {@link Saml2ResponseValidatorResult} with no errors
*/
public static Saml2ResponseValidatorResult success() {
return NO_ERRORS;
}
/**
* Construct a failure {@link Saml2ResponseValidatorResult} with the provided detail
*
* @param errors the list of errors
* @return an {@link Saml2ResponseValidatorResult} with the errors specified
*/
public static Saml2ResponseValidatorResult failure(Saml2Error... errors) {
return failure(Arrays.asList(errors));
}
/**
* Construct a failure {@link Saml2ResponseValidatorResult} with the provided detail
*
* @param errors the list of errors
* @return an {@link Saml2ResponseValidatorResult} with the errors specified
*/
public static Saml2ResponseValidatorResult failure(Collection<Saml2Error> errors) {
if (errors.isEmpty()) {
return NO_ERRORS;
}
return new Saml2ResponseValidatorResult(errors);
}
}

View File

@ -30,6 +30,7 @@ import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import javax.annotation.Nonnull;
import javax.xml.namespace.QName;
import net.shibboleth.utilities.java.support.resolver.CriteriaSet;
import net.shibboleth.utilities.java.support.xml.ParserPool;
@ -61,11 +62,14 @@ import org.opensaml.saml.saml2.assertion.StatementValidator;
import org.opensaml.saml.saml2.assertion.SubjectConfirmationValidator;
import org.opensaml.saml.saml2.assertion.impl.AudienceRestrictionConditionValidator;
import org.opensaml.saml.saml2.assertion.impl.BearerSubjectConfirmationValidator;
import org.opensaml.saml.saml2.assertion.impl.DelegationRestrictionConditionValidator;
import org.opensaml.saml.saml2.core.Assertion;
import org.opensaml.saml.saml2.core.Attribute;
import org.opensaml.saml.saml2.core.AttributeStatement;
import org.opensaml.saml.saml2.core.Condition;
import org.opensaml.saml.saml2.core.EncryptedAssertion;
import org.opensaml.saml.saml2.core.NameID;
import org.opensaml.saml.saml2.core.OneTimeUse;
import org.opensaml.saml.saml2.core.Response;
import org.opensaml.saml.saml2.core.SubjectConfirmation;
import org.opensaml.saml.saml2.core.impl.ResponseUnmarshaller;
@ -106,6 +110,7 @@ import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMap
import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.core.OpenSamlInitializationService;
import org.springframework.security.saml2.core.Saml2Error;
import org.springframework.security.saml2.core.Saml2ResponseValidatorResult;
import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
@ -126,6 +131,8 @@ import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_IS
import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_SIGNATURE;
import static org.springframework.security.saml2.core.Saml2ErrorCodes.MALFORMED_RESPONSE_DATA;
import static org.springframework.security.saml2.core.Saml2ErrorCodes.SUBJECT_NOT_FOUND;
import static org.springframework.security.saml2.core.Saml2ResponseValidatorResult.failure;
import static org.springframework.security.saml2.core.Saml2ResponseValidatorResult.success;
import static org.springframework.util.Assert.notNull;
/**
@ -191,16 +198,16 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
this.authoritiesMapper.mapAuthorities(getAssertionAuthorities(assertion)));
};
private Converter<AssertionToken, Saml2ResponseValidatorResult> assertionValidator = assertionToken -> {
ValidationContext context = createValidationContext(assertionToken);
return createDefaultAssertionValidator(context).convert(assertionToken);
};
private Converter<Saml2AuthenticationToken, SignatureTrustEngine> signatureTrustEngineConverter =
new SignatureTrustEngineConverter();
private Converter<Tuple, SAML20AssertionValidator> assertionValidatorConverter =
new SAML20AssertionValidatorConverter();
private Collection<ConditionValidator> conditionValidators =
Collections.singleton(new AudienceRestrictionConditionValidator());
private Converter<Tuple, ValidationContext> validationContextConverter =
new ValidationContextConverter();
private Converter<Saml2AuthenticationToken, Decrypter> decrypterConverter = new DecrypterConverter();
/**
* Creates an {@link OpenSamlAuthenticationProvider}
*/
@ -212,30 +219,43 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
}
/**
* Set the the collection of {@link ConditionValidator}s used when validating an assertion.
* Set the {@link Converter} to use for validating each {@link Assertion} in the SAML 2.0 Response.
*
* @param conditionValidators the collection of validators to use
* You can still invoke the default validator by delgating to
* {@link #createDefaultAssertionValidator(ValidationContext)}, like so:
*
* <pre>
* OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
* provider.setAssertionValidator(assertionToken -> {
* ValidationContext context = // ... build using authentication token
* Saml2ResponseValidatorResult result = createDefaultAssertionValidator(context)
* .convert(assertionToken)
* return result.concat(myCustomValiator.convert(assertionToken));
* });
* </pre>
*
* Consider taking a look at {@link #createValidationContext(AssertionToken)} to see how it
* constructs a {@link ValidationContext}.
*
* You can also use this method to configure the provider to use a different
* {@link ValidationContext} from the default, like so:
*
* <pre>
* OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
* ValidationContext context = // ...
* provider.setAssertionValidator(createDefaultAssertionValidator(context));
* </pre>
*
* It is not necessary to delegate to the default validator. You can safely replace it
* entirely with your own. Note that signature verification is performed as a separate
* step from this validator.
*
* @param assertionValidator
* @since 5.4
*/
public void setConditionValidators(
Collection<ConditionValidator> conditionValidators) {
Assert.notEmpty(conditionValidators, "conditionValidators cannot be empty");
this.conditionValidators = conditionValidators;
}
/**
* Set the strategy for retrieving the {@link ValidationContext} used when
* validating an assertion.
*
* @param validationContextConverter the strategy to use
* @since 5.4
*/
public void setValidationContextConverter(
Converter<Tuple, ValidationContext> validationContextConverter) {
Assert.notNull(validationContextConverter, "validationContextConverter cannot be empty");
this.validationContextConverter = validationContextConverter;
public void setAssertionValidator(Converter<AssertionToken, Saml2ResponseValidatorResult> assertionValidator) {
Assert.notNull(assertionValidator, "assertionValidator cannot be null");
this.assertionValidator = assertionValidator;
}
/**
@ -322,7 +342,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
}
boolean responseSigned = response.isSigned();
Map<String, Saml2AuthenticationException> validationExceptions = validateResponse(token, response);
Saml2ResponseValidatorResult result = validateResponse(token, response);
Decrypter decrypter = this.decrypterConverter.convert(token);
List<Assertion> assertions = decryptAssertions(decrypter, response);
@ -330,37 +350,37 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
throw authException(INVALID_SIGNATURE, "Either the response or one of the assertions is unsigned. " +
"Please either sign the response or all of the assertions.");
}
validationExceptions.putAll(validateAssertions(token, response));
result = result.concat(validateAssertions(token, response));
Assertion firstAssertion = CollectionUtils.firstElement(response.getAssertions());
NameID nameId = decryptPrincipal(decrypter, firstAssertion);
if (nameId == null || nameId.getValue() == null) {
validationExceptions.put(SUBJECT_NOT_FOUND, authException(SUBJECT_NOT_FOUND,
"Assertion [" + firstAssertion.getID() + "] is missing a subject"));
Saml2Error error = new Saml2Error(SUBJECT_NOT_FOUND,
"Assertion [" + firstAssertion.getID() + "] is missing a subject");
result = result.concat(error);
}
if (validationExceptions.isEmpty()) {
if (result.hasErrors()) {
Collection<Saml2Error> errors = result.getErrors();
if (logger.isTraceEnabled()) {
logger.debug("Found " + errors.size() + " validation errors in SAML response [" + response.getID() + "]: " +
errors);
} else if (logger.isDebugEnabled()) {
logger.debug("Found " + errors.size() + " validation errors in SAML response [" + response.getID() + "]");
}
Saml2Error first = errors.iterator().next();
throw authException(first.getErrorCode(), first.getDescription());
} else {
if (logger.isDebugEnabled()) {
logger.debug("Successfully processed SAML Response [" + response.getID() + "]");
}
} else {
if (logger.isTraceEnabled()) {
logger.debug("Found " + validationExceptions.size() + " validation errors in SAML response [" + response.getID() + "]: " +
validationExceptions.values());
} else if (logger.isDebugEnabled()) {
logger.debug("Found " + validationExceptions.size() + " validation errors in SAML response [" + response.getID() + "]");
}
}
if (!validationExceptions.isEmpty()) {
throw validationExceptions.values().iterator().next();
}
}
private Map<String, Saml2AuthenticationException> validateResponse
private Saml2ResponseValidatorResult validateResponse
(Saml2AuthenticationToken token, Response response) {
Map<String, Saml2AuthenticationException> validationExceptions = new HashMap<>();
Collection<Saml2Error> errors = new ArrayList<>();
String issuer = response.getIssuer().getValue();
if (response.isSigned()) {
@ -368,8 +388,8 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
try {
profileValidator.validate(response.getSignature());
} catch (Exception e) {
validationExceptions.put(INVALID_SIGNATURE, authException(INVALID_SIGNATURE,
"Invalid signature for SAML Response [" + response.getID() + "]: ", e));
errors.add(new Saml2Error(INVALID_SIGNATURE,
"Invalid signature for SAML Response [" + response.getID() + "]: "));
}
try {
@ -378,12 +398,12 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
criteriaSet.add(new EvaluableProtocolRoleDescriptorCriterion(new ProtocolCriterion(SAMLConstants.SAML20P_NS)));
criteriaSet.add(new EvaluableUsageCredentialCriterion(new UsageCriterion(UsageType.SIGNING)));
if (!this.signatureTrustEngineConverter.convert(token).validate(response.getSignature(), criteriaSet)) {
validationExceptions.put(INVALID_SIGNATURE, authException(INVALID_SIGNATURE,
errors.add(new Saml2Error(INVALID_SIGNATURE,
"Invalid signature for SAML Response [" + response.getID() + "]"));
}
} catch (Exception e) {
validationExceptions.put(INVALID_SIGNATURE, authException(INVALID_SIGNATURE,
"Invalid signature for SAML Response [" + response.getID() + "]: ", e));
errors.add(new Saml2Error(INVALID_SIGNATURE,
"Invalid signature for SAML Response [" + response.getID() + "]: "));
}
}
@ -391,16 +411,16 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
String location = token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation();
if (StringUtils.hasText(destination) && !destination.equals(location)) {
String message = "Invalid destination [" + destination + "] for SAML response [" + response.getID() + "]";
validationExceptions.put(INVALID_DESTINATION, authException(INVALID_DESTINATION, message));
errors.add(new Saml2Error(INVALID_DESTINATION, message));
}
String assertingPartyEntityId = token.getRelyingPartyRegistration().getAssertingPartyDetails().getEntityId();
if (!StringUtils.hasText(issuer) || !issuer.equals(assertingPartyEntityId)) {
String message = String.format("Invalid issuer [%s] for SAML response [%s]", issuer, response.getID());
validationExceptions.put(INVALID_ISSUER, authException(INVALID_ISSUER, message));
errors.add(new Saml2Error(INVALID_ISSUER, message));
}
return validationExceptions;
return failure(errors);
}
private List<Assertion> decryptAssertions
@ -418,41 +438,35 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
return response.getAssertions();
}
private Map<String, Saml2AuthenticationException> validateAssertions
private Saml2ResponseValidatorResult validateAssertions
(Saml2AuthenticationToken token, Response response) {
List<Assertion> assertions = response.getAssertions();
if (assertions.isEmpty()) {
throw authException(MALFORMED_RESPONSE_DATA, "No assertions found in response.");
}
Map<String, Saml2AuthenticationException> validationExceptions = new LinkedHashMap<>();
Saml2ResponseValidatorResult result = success();
if (logger.isDebugEnabled()) {
logger.debug("Validating " + assertions.size() + " assertions");
}
Tuple tuple = new Tuple(token, response);
SAML20AssertionValidator validator = this.assertionValidatorConverter.convert(tuple);
ValidationContext context = this.validationContextConverter.convert(tuple);
ValidationContext signatureContext = new ValidationContext
(Collections.singletonMap(SIGNATURE_REQUIRED, false)); // check already performed
SignatureTrustEngine engine = this.signatureTrustEngineConverter.convert(token);
Converter<AssertionToken, Saml2ResponseValidatorResult> signatureValidator =
createDefaultAssertionValidator(INVALID_SIGNATURE,
SAML20AssertionValidators.createSignatureValidator(engine), signatureContext);
for (Assertion assertion : assertions) {
if (logger.isTraceEnabled()) {
logger.trace("Validating assertion " + assertion.getID());
}
try {
if (validator.validate(assertion, context) != ValidationResult.VALID) {
String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s",
assertion.getID(), ((Response) assertion.getParent()).getID(),
context.getValidationFailureMessage());
validationExceptions.put(INVALID_ASSERTION, authException(INVALID_ASSERTION, message));
}
} catch (Exception e) {
String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s",
assertion.getID(), ((Response) assertion.getParent()).getID(),
e.getMessage());
validationExceptions.put(INVALID_ASSERTION, authException(INVALID_ASSERTION, message, e));
}
AssertionToken assertionToken = new AssertionToken(assertion, token);
result = result
.concat(signatureValidator.convert(assertionToken))
.concat(this.assertionValidator.convert(assertionToken));
}
return validationExceptions;
return result;
}
private boolean isSigned(boolean responseSigned, List<Assertion> assertions) {
@ -561,45 +575,111 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
}
}
private class ValidationContextConverter implements Converter<Tuple, ValidationContext> {
public static Converter<AssertionToken, Saml2ResponseValidatorResult>
createDefaultAssertionValidator(ValidationContext context) {
@Override
public ValidationContext convert(Tuple tuple) {
String audience = tuple.authentication.getRelyingPartyRegistration().getEntityId();
String recipient = tuple.authentication.getRelyingPartyRegistration().getAssertionConsumerServiceLocation();
Map<String, Object> params = new HashMap<>();
params.put(CLOCK_SKEW, OpenSamlAuthenticationProvider.this.responseTimeValidationSkew.toMillis());
params.put(COND_VALID_AUDIENCES, singleton(audience));
params.put(SC_VALID_RECIPIENTS, singleton(recipient));
params.put(SIGNATURE_REQUIRED, false); // this verification is performed earlier
return new ValidationContext(params);
}
return createDefaultAssertionValidator(INVALID_ASSERTION,
SAML20AssertionValidators.createAttributeValidator(), context);
}
private class SAML20AssertionValidatorConverter implements Converter<Tuple, SAML20AssertionValidator> {
private final Collection<SubjectConfirmationValidator> subjects = new ArrayList<>();
private final Collection<StatementValidator> statements = new ArrayList<>();
private final SignaturePrevalidator validator = new SAMLSignatureProfileValidator();
private static Converter<AssertionToken, Saml2ResponseValidatorResult>
createDefaultAssertionValidator(String errorCode, SAML20AssertionValidator validator, ValidationContext context) {
SAML20AssertionValidatorConverter() {
this.subjects.add(new BearerSubjectConfirmationValidator() {
return assertionToken -> {
Assertion assertion = assertionToken.assertion;
try {
ValidationResult result = validator.validate(assertion, context);
if (result == ValidationResult.VALID) {
return success();
}
} catch (Exception e) {
String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s",
assertion.getID(), ((Response) assertion.getParent()).getID(),
e.getMessage());
return failure(new Saml2Error(errorCode, message));
}
String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s",
assertion.getID(), ((Response) assertion.getParent()).getID(),
context.getValidationFailureMessage());
return failure(new Saml2Error(errorCode, message));
};
}
private ValidationContext createValidationContext(AssertionToken assertionToken) {
String audience = assertionToken.token.getRelyingPartyRegistration().getEntityId();
String recipient = assertionToken.token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation();
Map<String, Object> params = new HashMap<>();
params.put(CLOCK_SKEW, OpenSamlAuthenticationProvider.this.responseTimeValidationSkew.toMillis());
params.put(COND_VALID_AUDIENCES, singleton(audience));
params.put(SC_VALID_RECIPIENTS, singleton(recipient));
return new ValidationContext(params);
}
private static class SAML20AssertionValidators {
private static final Collection<ConditionValidator> conditions = new ArrayList<>();
private static final Collection<SubjectConfirmationValidator> subjects = new ArrayList<>();
private static final Collection<StatementValidator> statements = new ArrayList<>();
private static final SignaturePrevalidator validator = new SAMLSignatureProfileValidator();
static {
conditions.add(new AudienceRestrictionConditionValidator());
conditions.add(new DelegationRestrictionConditionValidator());
conditions.add(new ConditionValidator() {
@Nonnull
@Override
public QName getServicedCondition() {
return OneTimeUse.DEFAULT_ELEMENT_NAME;
}
@Nonnull
@Override
public ValidationResult validate(Condition condition, Assertion assertion, ValidationContext context) {
// applications should validate their own OneTimeUse conditions
return ValidationResult.VALID;
}
});
subjects.add(new BearerSubjectConfirmationValidator() {
@Nonnull
@Override
protected ValidationResult validateAddress(@Nonnull SubjectConfirmation confirmation,
@Nonnull Assertion assertion, @Nonnull ValidationContext context) {
// skipping address validation - gh-7514
// applications should validate their own addresses - gh-7514
return ValidationResult.VALID;
}
});
}
@Override
public SAML20AssertionValidator convert(Tuple tuple) {
Collection<ConditionValidator> conditions =
OpenSamlAuthenticationProvider.this.conditionValidators;
return new SAML20AssertionValidator(conditions, this.subjects, this.statements,
OpenSamlAuthenticationProvider.this.signatureTrustEngineConverter.convert(tuple.authentication),
this.validator);
static SAML20AssertionValidator createAttributeValidator() {
return new SAML20AssertionValidator(conditions, subjects, statements, null, null) {
@Nonnull
@Override
protected ValidationResult validateSignature(Assertion token, ValidationContext context) {
return ValidationResult.VALID;
}
};
}
static SAML20AssertionValidator createSignatureValidator(SignatureTrustEngine engine) {
return new SAML20AssertionValidator(new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),
engine, validator) {
@Nonnull
@Override
protected ValidationResult validateConditions(Assertion assertion, ValidationContext context) {
return ValidationResult.VALID;
}
@Nonnull
@Override
protected ValidationResult validateSubjectConfirmation(Assertion assertion, ValidationContext context) {
return ValidationResult.VALID;
}
@Nonnull
@Override
protected ValidationResult validateStatements(Assertion assertion, ValidationContext context) {
return ValidationResult.VALID;
}
};
}
}
@ -643,25 +723,25 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
}
/**
* A tuple containing the authentication token and the associated OpenSAML {@link Response}.
* A tuple containing an OpenSAML {@link Assertion} and its associated authentication token.
*
* @since 5.4
*/
public static class Tuple {
private final Saml2AuthenticationToken authentication;
private final Response response;
public static class AssertionToken {
private final Saml2AuthenticationToken token;
private final Assertion assertion;
private Tuple(Saml2AuthenticationToken authentication, Response response) {
this.authentication = authentication;
this.response = response;
private AssertionToken(Assertion assertion, Saml2AuthenticationToken token) {
this.token = token;
this.assertion = assertion;
}
public Saml2AuthenticationToken getAuthentication() {
return this.authentication;
public Assertion getAssertion() {
return this.assertion;
}
public Response getResponse() {
return this.response;
public Saml2AuthenticationToken getToken() {
return this.token;
}
}
}

View File

@ -0,0 +1,89 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.saml2.core;
import org.junit.Test;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
/**
* Tests for verifying {@link Saml2ResponseValidatorResult}
*
* @author Josh Cummings
*/
public class Saml2ResponseValidatorResultTests {
private static final Saml2Error DETAIL = new Saml2Error(
"error", "description");
@Test
public void successWhenInvokedThenReturnsSuccessfulResult() {
Saml2ResponseValidatorResult success = Saml2ResponseValidatorResult.success();
assertThat(success.hasErrors()).isFalse();
}
@Test
public void failureWhenInvokedWithDetailReturnsFailureResultIncludingDetail() {
Saml2ResponseValidatorResult failure = Saml2ResponseValidatorResult.failure(DETAIL);
assertThat(failure.hasErrors()).isTrue();
assertThat(failure.getErrors()).containsExactly(DETAIL);
}
@Test
public void failureWhenInvokedWithMultipleDetailsReturnsFailureResultIncludingAll() {
Saml2ResponseValidatorResult failure = Saml2ResponseValidatorResult.failure(DETAIL, DETAIL);
assertThat(failure.hasErrors()).isTrue();
assertThat(failure.getErrors()).containsExactly(DETAIL, DETAIL);
}
@Test
public void concatErrorWhenInvokedThenReturnsCopyContainingAll() {
Saml2ResponseValidatorResult failure = Saml2ResponseValidatorResult.failure(DETAIL);
Saml2ResponseValidatorResult added = failure.concat(DETAIL);
assertThat(added.hasErrors()).isTrue();
assertThat(added.getErrors()).containsExactly(DETAIL, DETAIL);
assertThat(failure).isNotSameAs(added);
}
@Test
public void concatResultWhenInvokedThenReturnsCopyContainingAll() {
Saml2ResponseValidatorResult failure = Saml2ResponseValidatorResult.failure(DETAIL);
Saml2ResponseValidatorResult merged = failure
.concat(failure)
.concat(failure);
assertThat(merged.hasErrors()).isTrue();
assertThat(merged.getErrors()).containsExactly(DETAIL, DETAIL, DETAIL);
assertThat(failure).isNotSameAs(merged);
}
@Test
public void concatErrorWhenNullThenIllegalArgument() {
assertThatThrownBy(() -> Saml2ResponseValidatorResult.failure(DETAIL)
.concat((Saml2Error) null))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void concatResultWhenNullThenIllegalArgument() {
assertThatThrownBy(() -> Saml2ResponseValidatorResult.failure(DETAIL)
.concat((Saml2ResponseValidatorResult) null))
.isInstanceOf(IllegalArgumentException.class);
}
}

View File

@ -45,12 +45,9 @@ import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
import org.opensaml.core.xml.io.Marshaller;
import org.opensaml.core.xml.io.MarshallingException;
import org.opensaml.saml.common.assertion.ValidationContext;
import org.opensaml.saml.common.assertion.ValidationResult;
import org.opensaml.saml.saml2.assertion.impl.OneTimeUseConditionValidator;
import org.opensaml.saml.saml2.core.Assertion;
import org.opensaml.saml.saml2.core.AttributeStatement;
import org.opensaml.saml.saml2.core.AttributeValue;
import org.opensaml.saml.saml2.core.Condition;
import org.opensaml.saml.saml2.core.EncryptedAssertion;
import org.opensaml.saml.saml2.core.EncryptedID;
import org.opensaml.saml.saml2.core.NameID;
@ -60,13 +57,17 @@ import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.xml.sax.InputSource;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.core.Authentication;
import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.core.Saml2Error;
import org.springframework.security.saml2.core.Saml2ResponseValidatorResult;
import org.springframework.security.saml2.credentials.Saml2X509Credential;
import static java.util.Collections.singleton;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
@ -76,11 +77,14 @@ import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getB
import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getMarshallerFactory;
import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS;
import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SIGNATURE_REQUIRED;
import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_ASSERTION;
import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_SIGNATURE;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyEncryptingCredential;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyPrivateCredential;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartySigningCredential;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyDecryptingCredential;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyVerifyingCredential;
import static org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider.createDefaultAssertionValidator;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.assertion;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.attributeStatements;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.encrypted;
@ -365,10 +369,13 @@ public class OpenSamlAuthenticationProviderTests {
}
@Test
public void authenticateWhenConditionValidatorsCustomizedThenUses() throws Exception {
OneTimeUseConditionValidator validator = mock(OneTimeUseConditionValidator.class);
public void authenticateWhenDelegatingToDefaultAssertionValidatorThenUses() {
OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
provider.setConditionValidators(Collections.singleton(validator));
provider.setAssertionValidator(assertionToken -> {
ValidationContext context = new ValidationContext();
return createDefaultAssertionValidator(context).convert(assertionToken)
.concat(new Saml2Error("wrong error", "wrong error"));
});
Response response = response();
Assertion assertion = assertion();
OneTimeUse oneTimeUse = build(OneTimeUse.DEFAULT_ELEMENT_NAME);
@ -376,11 +383,46 @@ public class OpenSamlAuthenticationProviderTests {
response.getAssertions().add(assertion);
signed(response, assertingPartySigningCredential(), ASSERTING_PARTY_ENTITY_ID);
Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
when(validator.getServicedCondition()).thenReturn(OneTimeUse.DEFAULT_ELEMENT_NAME);
when(validator.validate(any(Condition.class), any(Assertion.class), any(ValidationContext.class)))
.thenReturn(ValidationResult.VALID);
assertThatThrownBy(() -> provider.authenticate(token))
.isInstanceOf(Saml2AuthenticationException.class)
.hasFieldOrPropertyWithValue("error.errorCode", INVALID_ASSERTION);
}
@Test
public void authenticateWhenCustomAssertionValidatorThenUses() {
Converter<OpenSamlAuthenticationProvider.AssertionToken, Saml2ResponseValidatorResult> validator =
mock(Converter.class);
OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
provider.setAssertionValidator(assertionToken -> {
ValidationContext context = new ValidationContext(
Collections.singletonMap(SC_VALID_RECIPIENTS, singleton(DESTINATION)));
return createDefaultAssertionValidator(context).convert(assertionToken)
.concat(validator.convert(assertionToken));
});
Response response = response();
Assertion assertion = assertion();
response.getAssertions().add(assertion);
signed(response, assertingPartySigningCredential(), ASSERTING_PARTY_ENTITY_ID);
Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
when(validator.convert(any(OpenSamlAuthenticationProvider.AssertionToken.class)))
.thenReturn(Saml2ResponseValidatorResult.success());
provider.authenticate(token);
verify(validator).validate(any(Condition.class), any(Assertion.class), any(ValidationContext.class));
verify(validator).convert(any(OpenSamlAuthenticationProvider.AssertionToken.class));
}
@Test
public void authenticateWhenDefaultConditionValidatorNotUsedThenSignatureStillChecked() {
OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
provider.setAssertionValidator(assertionToken -> Saml2ResponseValidatorResult.success());
Response response = response();
Assertion assertion = assertion();
signed(assertion, relyingPartyDecryptingCredential(), RELYING_PARTY_ENTITY_ID); // broken signature
response.getAssertions().add(assertion);
signed(response, assertingPartySigningCredential(), ASSERTING_PARTY_ENTITY_ID);
Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
assertThatThrownBy(() -> provider.authenticate(token))
.isInstanceOf(Saml2AuthenticationException.class)
.hasFieldOrPropertyWithValue("error.errorCode", INVALID_SIGNATURE);
}
@Test
@ -391,7 +433,7 @@ public class OpenSamlAuthenticationProviderTests {
ValidationContext context = mock(ValidationContext.class);
when(context.getStaticParameters()).thenReturn(parameters);
OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
provider.setValidationContextConverter(tuple -> context);
provider.setAssertionValidator(assertionToken -> createDefaultAssertionValidator(context).convert(assertionToken));
Response response = response();
Assertion assertion = assertion();
response.getAssertions().add(assertion);
@ -402,17 +444,8 @@ public class OpenSamlAuthenticationProviderTests {
}
@Test
public void setValidationContextConverterWhenNullThenIllegalArgument() {
assertThatCode(() -> this.provider.setValidationContextConverter(null))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void setConditionValidatorsWhenNullOrEmptyThenIllegalArgument() {
assertThatCode(() -> this.provider.setConditionValidators(null))
.isInstanceOf(IllegalArgumentException.class);
assertThatCode(() -> this.provider.setConditionValidators(Collections.emptyList()))
public void setAssertionValidatorWhenNullThenIllegalArgument() {
assertThatCode(() -> this.provider.setAssertionValidator(null))
.isInstanceOf(IllegalArgumentException.class);
}

View File

@ -242,7 +242,7 @@ public class Saml2LoginIntegrationTests {
sendResponse(response, "/login?error")
.andExpect(
saml2AuthenticationExceptionMatcher(
"invalid_assertion",
"invalid_signature",
containsString("Invalid assertion [assertion] for SAML response")
)
);
@ -288,9 +288,9 @@ public class Saml2LoginIntegrationTests {
.andExpect(unauthenticated())
.andExpect(
saml2AuthenticationExceptionMatcher(
"invalid_issuer",
"invalid_signature",
containsString(
"Invalid issuer"
"Invalid signature"
)
)
);