diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/Saml2ResponseValidatorResult.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/Saml2ResponseValidatorResult.java new file mode 100644 index 0000000000..3df1e7c3da --- /dev/null +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/Saml2ResponseValidatorResult.java @@ -0,0 +1,121 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.core; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; + +import org.springframework.util.Assert; + +/** + * A result emitted from a SAML 2.0 Response validation attempt + * + * @author Josh Cummings + * @since 5.4 + */ +public final class Saml2ResponseValidatorResult { + static final Saml2ResponseValidatorResult NO_ERRORS = new Saml2ResponseValidatorResult(Collections.emptyList()); + + private final Collection errors; + + private Saml2ResponseValidatorResult(Collection errors) { + Assert.notNull(errors, "errors cannot be null"); + this.errors = new ArrayList<>(errors); + } + + /** + * Say whether this result indicates success + * + * @return whether this result has errors + */ + public boolean hasErrors() { + return !this.errors.isEmpty(); + } + + /** + * Return error details regarding the validation attempt + * + * @return the collection of results in this result, if any; returns an empty list otherwise + */ + public Collection getErrors() { + return Collections.unmodifiableCollection(this.errors); + } + + /** + * Return a new {@link Saml2ResponseValidatorResult} that contains + * both the given {@link Saml2Error} and the errors from the result + * + * @param error the {@link Saml2Error} to append + * @return a new {@link Saml2ResponseValidatorResult} for further reporting + */ + public Saml2ResponseValidatorResult concat(Saml2Error error) { + Assert.notNull(error, "error cannot be null"); + Collection errors = new ArrayList<>(this.errors); + errors.add(error); + return failure(errors); + } + + /** + * Return a new {@link Saml2ResponseValidatorResult} that contains + * the errors from the given {@link Saml2ResponseValidatorResult} as well + * as this result. + * + * @param result the {@link Saml2ResponseValidatorResult} to merge with this one + * @return a new {@link Saml2ResponseValidatorResult} for further reporting + */ + public Saml2ResponseValidatorResult concat(Saml2ResponseValidatorResult result) { + Assert.notNull(result, "result cannot be null"); + Collection errors = new ArrayList<>(this.errors); + errors.addAll(result.getErrors()); + return failure(errors); + } + + /** + * Construct a successful {@link Saml2ResponseValidatorResult} + * + * @return an {@link Saml2ResponseValidatorResult} with no errors + */ + public static Saml2ResponseValidatorResult success() { + return NO_ERRORS; + } + + /** + * Construct a failure {@link Saml2ResponseValidatorResult} with the provided detail + * + * @param errors the list of errors + * @return an {@link Saml2ResponseValidatorResult} with the errors specified + */ + public static Saml2ResponseValidatorResult failure(Saml2Error... errors) { + return failure(Arrays.asList(errors)); + } + + /** + * Construct a failure {@link Saml2ResponseValidatorResult} with the provided detail + * + * @param errors the list of errors + * @return an {@link Saml2ResponseValidatorResult} with the errors specified + */ + public static Saml2ResponseValidatorResult failure(Collection errors) { + if (errors.isEmpty()) { + return NO_ERRORS; + } + + return new Saml2ResponseValidatorResult(errors); + } +} diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java index e2c8bdb35c..25e4350a17 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java @@ -30,6 +30,7 @@ import java.util.Map; import java.util.Set; import java.util.function.Function; import javax.annotation.Nonnull; +import javax.xml.namespace.QName; import net.shibboleth.utilities.java.support.resolver.CriteriaSet; import net.shibboleth.utilities.java.support.xml.ParserPool; @@ -61,11 +62,14 @@ import org.opensaml.saml.saml2.assertion.StatementValidator; import org.opensaml.saml.saml2.assertion.SubjectConfirmationValidator; import org.opensaml.saml.saml2.assertion.impl.AudienceRestrictionConditionValidator; import org.opensaml.saml.saml2.assertion.impl.BearerSubjectConfirmationValidator; +import org.opensaml.saml.saml2.assertion.impl.DelegationRestrictionConditionValidator; import org.opensaml.saml.saml2.core.Assertion; import org.opensaml.saml.saml2.core.Attribute; import org.opensaml.saml.saml2.core.AttributeStatement; +import org.opensaml.saml.saml2.core.Condition; import org.opensaml.saml.saml2.core.EncryptedAssertion; import org.opensaml.saml.saml2.core.NameID; +import org.opensaml.saml.saml2.core.OneTimeUse; import org.opensaml.saml.saml2.core.Response; import org.opensaml.saml.saml2.core.SubjectConfirmation; import org.opensaml.saml.saml2.core.impl.ResponseUnmarshaller; @@ -106,6 +110,7 @@ import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMap import org.springframework.security.saml2.Saml2Exception; import org.springframework.security.saml2.core.OpenSamlInitializationService; import org.springframework.security.saml2.core.Saml2Error; +import org.springframework.security.saml2.core.Saml2ResponseValidatorResult; import org.springframework.security.saml2.core.Saml2X509Credential; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -126,6 +131,8 @@ import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_IS import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_SIGNATURE; import static org.springframework.security.saml2.core.Saml2ErrorCodes.MALFORMED_RESPONSE_DATA; import static org.springframework.security.saml2.core.Saml2ErrorCodes.SUBJECT_NOT_FOUND; +import static org.springframework.security.saml2.core.Saml2ResponseValidatorResult.failure; +import static org.springframework.security.saml2.core.Saml2ResponseValidatorResult.success; import static org.springframework.util.Assert.notNull; /** @@ -191,16 +198,16 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi this.authoritiesMapper.mapAuthorities(getAssertionAuthorities(assertion))); }; + private Converter assertionValidator = assertionToken -> { + ValidationContext context = createValidationContext(assertionToken); + return createDefaultAssertionValidator(context).convert(assertionToken); + }; + private Converter signatureTrustEngineConverter = new SignatureTrustEngineConverter(); - private Converter assertionValidatorConverter = - new SAML20AssertionValidatorConverter(); - private Collection conditionValidators = - Collections.singleton(new AudienceRestrictionConditionValidator()); - private Converter validationContextConverter = - new ValidationContextConverter(); private Converter decrypterConverter = new DecrypterConverter(); + /** * Creates an {@link OpenSamlAuthenticationProvider} */ @@ -212,30 +219,43 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi } /** - * Set the the collection of {@link ConditionValidator}s used when validating an assertion. + * Set the {@link Converter} to use for validating each {@link Assertion} in the SAML 2.0 Response. * - * @param conditionValidators the collection of validators to use + * You can still invoke the default validator by delgating to + * {@link #createDefaultAssertionValidator(ValidationContext)}, like so: + * + *
+	 *	OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
+	 *  provider.setAssertionValidator(assertionToken -> {
+	 *		ValidationContext context = // ... build using authentication token
+	 *		Saml2ResponseValidatorResult result = createDefaultAssertionValidator(context)
+	 *			.convert(assertionToken)
+	 *		return result.concat(myCustomValiator.convert(assertionToken));
+	 *  });
+	 * 
+ * + * Consider taking a look at {@link #createValidationContext(AssertionToken)} to see how it + * constructs a {@link ValidationContext}. + * + * You can also use this method to configure the provider to use a different + * {@link ValidationContext} from the default, like so: + * + *
+	 *	OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
+	 *	ValidationContext context = // ...
+	 *	provider.setAssertionValidator(createDefaultAssertionValidator(context));
+	 * 
+ * + * It is not necessary to delegate to the default validator. You can safely replace it + * entirely with your own. Note that signature verification is performed as a separate + * step from this validator. + * + * @param assertionValidator * @since 5.4 */ - public void setConditionValidators( - Collection conditionValidators) { - - Assert.notEmpty(conditionValidators, "conditionValidators cannot be empty"); - this.conditionValidators = conditionValidators; - } - - /** - * Set the strategy for retrieving the {@link ValidationContext} used when - * validating an assertion. - * - * @param validationContextConverter the strategy to use - * @since 5.4 - */ - public void setValidationContextConverter( - Converter validationContextConverter) { - - Assert.notNull(validationContextConverter, "validationContextConverter cannot be empty"); - this.validationContextConverter = validationContextConverter; + public void setAssertionValidator(Converter assertionValidator) { + Assert.notNull(assertionValidator, "assertionValidator cannot be null"); + this.assertionValidator = assertionValidator; } /** @@ -322,7 +342,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi } boolean responseSigned = response.isSigned(); - Map validationExceptions = validateResponse(token, response); + Saml2ResponseValidatorResult result = validateResponse(token, response); Decrypter decrypter = this.decrypterConverter.convert(token); List assertions = decryptAssertions(decrypter, response); @@ -330,37 +350,37 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi throw authException(INVALID_SIGNATURE, "Either the response or one of the assertions is unsigned. " + "Please either sign the response or all of the assertions."); } - validationExceptions.putAll(validateAssertions(token, response)); + result = result.concat(validateAssertions(token, response)); Assertion firstAssertion = CollectionUtils.firstElement(response.getAssertions()); NameID nameId = decryptPrincipal(decrypter, firstAssertion); if (nameId == null || nameId.getValue() == null) { - validationExceptions.put(SUBJECT_NOT_FOUND, authException(SUBJECT_NOT_FOUND, - "Assertion [" + firstAssertion.getID() + "] is missing a subject")); + Saml2Error error = new Saml2Error(SUBJECT_NOT_FOUND, + "Assertion [" + firstAssertion.getID() + "] is missing a subject"); + result = result.concat(error); } - if (validationExceptions.isEmpty()) { + if (result.hasErrors()) { + Collection errors = result.getErrors(); + if (logger.isTraceEnabled()) { + logger.debug("Found " + errors.size() + " validation errors in SAML response [" + response.getID() + "]: " + + errors); + } else if (logger.isDebugEnabled()) { + logger.debug("Found " + errors.size() + " validation errors in SAML response [" + response.getID() + "]"); + } + Saml2Error first = errors.iterator().next(); + throw authException(first.getErrorCode(), first.getDescription()); + } else { if (logger.isDebugEnabled()) { logger.debug("Successfully processed SAML Response [" + response.getID() + "]"); } - } else { - if (logger.isTraceEnabled()) { - logger.debug("Found " + validationExceptions.size() + " validation errors in SAML response [" + response.getID() + "]: " + - validationExceptions.values()); - } else if (logger.isDebugEnabled()) { - logger.debug("Found " + validationExceptions.size() + " validation errors in SAML response [" + response.getID() + "]"); - } - } - - if (!validationExceptions.isEmpty()) { - throw validationExceptions.values().iterator().next(); } } - private Map validateResponse + private Saml2ResponseValidatorResult validateResponse (Saml2AuthenticationToken token, Response response) { - Map validationExceptions = new HashMap<>(); + Collection errors = new ArrayList<>(); String issuer = response.getIssuer().getValue(); if (response.isSigned()) { @@ -368,8 +388,8 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi try { profileValidator.validate(response.getSignature()); } catch (Exception e) { - validationExceptions.put(INVALID_SIGNATURE, authException(INVALID_SIGNATURE, - "Invalid signature for SAML Response [" + response.getID() + "]: ", e)); + errors.add(new Saml2Error(INVALID_SIGNATURE, + "Invalid signature for SAML Response [" + response.getID() + "]: ")); } try { @@ -378,12 +398,12 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi criteriaSet.add(new EvaluableProtocolRoleDescriptorCriterion(new ProtocolCriterion(SAMLConstants.SAML20P_NS))); criteriaSet.add(new EvaluableUsageCredentialCriterion(new UsageCriterion(UsageType.SIGNING))); if (!this.signatureTrustEngineConverter.convert(token).validate(response.getSignature(), criteriaSet)) { - validationExceptions.put(INVALID_SIGNATURE, authException(INVALID_SIGNATURE, + errors.add(new Saml2Error(INVALID_SIGNATURE, "Invalid signature for SAML Response [" + response.getID() + "]")); } } catch (Exception e) { - validationExceptions.put(INVALID_SIGNATURE, authException(INVALID_SIGNATURE, - "Invalid signature for SAML Response [" + response.getID() + "]: ", e)); + errors.add(new Saml2Error(INVALID_SIGNATURE, + "Invalid signature for SAML Response [" + response.getID() + "]: ")); } } @@ -391,16 +411,16 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi String location = token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation(); if (StringUtils.hasText(destination) && !destination.equals(location)) { String message = "Invalid destination [" + destination + "] for SAML response [" + response.getID() + "]"; - validationExceptions.put(INVALID_DESTINATION, authException(INVALID_DESTINATION, message)); + errors.add(new Saml2Error(INVALID_DESTINATION, message)); } String assertingPartyEntityId = token.getRelyingPartyRegistration().getAssertingPartyDetails().getEntityId(); if (!StringUtils.hasText(issuer) || !issuer.equals(assertingPartyEntityId)) { String message = String.format("Invalid issuer [%s] for SAML response [%s]", issuer, response.getID()); - validationExceptions.put(INVALID_ISSUER, authException(INVALID_ISSUER, message)); + errors.add(new Saml2Error(INVALID_ISSUER, message)); } - return validationExceptions; + return failure(errors); } private List decryptAssertions @@ -418,41 +438,35 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi return response.getAssertions(); } - private Map validateAssertions + private Saml2ResponseValidatorResult validateAssertions (Saml2AuthenticationToken token, Response response) { List assertions = response.getAssertions(); if (assertions.isEmpty()) { throw authException(MALFORMED_RESPONSE_DATA, "No assertions found in response."); } - Map validationExceptions = new LinkedHashMap<>(); + Saml2ResponseValidatorResult result = success(); if (logger.isDebugEnabled()) { logger.debug("Validating " + assertions.size() + " assertions"); } - Tuple tuple = new Tuple(token, response); - SAML20AssertionValidator validator = this.assertionValidatorConverter.convert(tuple); - ValidationContext context = this.validationContextConverter.convert(tuple); + ValidationContext signatureContext = new ValidationContext + (Collections.singletonMap(SIGNATURE_REQUIRED, false)); // check already performed + SignatureTrustEngine engine = this.signatureTrustEngineConverter.convert(token); + Converter signatureValidator = + createDefaultAssertionValidator(INVALID_SIGNATURE, + SAML20AssertionValidators.createSignatureValidator(engine), signatureContext); for (Assertion assertion : assertions) { if (logger.isTraceEnabled()) { logger.trace("Validating assertion " + assertion.getID()); } - try { - if (validator.validate(assertion, context) != ValidationResult.VALID) { - String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s", - assertion.getID(), ((Response) assertion.getParent()).getID(), - context.getValidationFailureMessage()); - validationExceptions.put(INVALID_ASSERTION, authException(INVALID_ASSERTION, message)); - } - } catch (Exception e) { - String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s", - assertion.getID(), ((Response) assertion.getParent()).getID(), - e.getMessage()); - validationExceptions.put(INVALID_ASSERTION, authException(INVALID_ASSERTION, message, e)); - } + AssertionToken assertionToken = new AssertionToken(assertion, token); + result = result + .concat(signatureValidator.convert(assertionToken)) + .concat(this.assertionValidator.convert(assertionToken)); } - return validationExceptions; + return result; } private boolean isSigned(boolean responseSigned, List assertions) { @@ -561,45 +575,111 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi } } - private class ValidationContextConverter implements Converter { + public static Converter + createDefaultAssertionValidator(ValidationContext context) { - @Override - public ValidationContext convert(Tuple tuple) { - String audience = tuple.authentication.getRelyingPartyRegistration().getEntityId(); - String recipient = tuple.authentication.getRelyingPartyRegistration().getAssertionConsumerServiceLocation(); - Map params = new HashMap<>(); - params.put(CLOCK_SKEW, OpenSamlAuthenticationProvider.this.responseTimeValidationSkew.toMillis()); - params.put(COND_VALID_AUDIENCES, singleton(audience)); - params.put(SC_VALID_RECIPIENTS, singleton(recipient)); - params.put(SIGNATURE_REQUIRED, false); // this verification is performed earlier - return new ValidationContext(params); - } + return createDefaultAssertionValidator(INVALID_ASSERTION, + SAML20AssertionValidators.createAttributeValidator(), context); } - private class SAML20AssertionValidatorConverter implements Converter { - private final Collection subjects = new ArrayList<>(); - private final Collection statements = new ArrayList<>(); - private final SignaturePrevalidator validator = new SAMLSignatureProfileValidator(); + private static Converter + createDefaultAssertionValidator(String errorCode, SAML20AssertionValidator validator, ValidationContext context) { - SAML20AssertionValidatorConverter() { - this.subjects.add(new BearerSubjectConfirmationValidator() { + return assertionToken -> { + Assertion assertion = assertionToken.assertion; + try { + ValidationResult result = validator.validate(assertion, context); + if (result == ValidationResult.VALID) { + return success(); + } + } catch (Exception e) { + String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s", + assertion.getID(), ((Response) assertion.getParent()).getID(), + e.getMessage()); + return failure(new Saml2Error(errorCode, message)); + } + String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s", + assertion.getID(), ((Response) assertion.getParent()).getID(), + context.getValidationFailureMessage()); + return failure(new Saml2Error(errorCode, message)); + }; + } + + private ValidationContext createValidationContext(AssertionToken assertionToken) { + String audience = assertionToken.token.getRelyingPartyRegistration().getEntityId(); + String recipient = assertionToken.token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation(); + Map params = new HashMap<>(); + params.put(CLOCK_SKEW, OpenSamlAuthenticationProvider.this.responseTimeValidationSkew.toMillis()); + params.put(COND_VALID_AUDIENCES, singleton(audience)); + params.put(SC_VALID_RECIPIENTS, singleton(recipient)); + return new ValidationContext(params); + } + + private static class SAML20AssertionValidators { + private static final Collection conditions = new ArrayList<>(); + private static final Collection subjects = new ArrayList<>(); + private static final Collection statements = new ArrayList<>(); + private static final SignaturePrevalidator validator = new SAMLSignatureProfileValidator(); + + static { + conditions.add(new AudienceRestrictionConditionValidator()); + conditions.add(new DelegationRestrictionConditionValidator()); + conditions.add(new ConditionValidator() { + @Nonnull + @Override + public QName getServicedCondition() { + return OneTimeUse.DEFAULT_ELEMENT_NAME; + } + + @Nonnull + @Override + public ValidationResult validate(Condition condition, Assertion assertion, ValidationContext context) { + // applications should validate their own OneTimeUse conditions + return ValidationResult.VALID; + } + }); + subjects.add(new BearerSubjectConfirmationValidator() { @Nonnull @Override protected ValidationResult validateAddress(@Nonnull SubjectConfirmation confirmation, @Nonnull Assertion assertion, @Nonnull ValidationContext context) { - // skipping address validation - gh-7514 + // applications should validate their own addresses - gh-7514 return ValidationResult.VALID; } }); } - @Override - public SAML20AssertionValidator convert(Tuple tuple) { - Collection conditions = - OpenSamlAuthenticationProvider.this.conditionValidators; - return new SAML20AssertionValidator(conditions, this.subjects, this.statements, - OpenSamlAuthenticationProvider.this.signatureTrustEngineConverter.convert(tuple.authentication), - this.validator); + static SAML20AssertionValidator createAttributeValidator() { + return new SAML20AssertionValidator(conditions, subjects, statements, null, null) { + @Nonnull + @Override + protected ValidationResult validateSignature(Assertion token, ValidationContext context) { + return ValidationResult.VALID; + } + }; + } + + static SAML20AssertionValidator createSignatureValidator(SignatureTrustEngine engine) { + return new SAML20AssertionValidator(new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), + engine, validator) { + @Nonnull + @Override + protected ValidationResult validateConditions(Assertion assertion, ValidationContext context) { + return ValidationResult.VALID; + } + + @Nonnull + @Override + protected ValidationResult validateSubjectConfirmation(Assertion assertion, ValidationContext context) { + return ValidationResult.VALID; + } + + @Nonnull + @Override + protected ValidationResult validateStatements(Assertion assertion, ValidationContext context) { + return ValidationResult.VALID; + } + }; } } @@ -643,25 +723,25 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi } /** - * A tuple containing the authentication token and the associated OpenSAML {@link Response}. + * A tuple containing an OpenSAML {@link Assertion} and its associated authentication token. * * @since 5.4 */ - public static class Tuple { - private final Saml2AuthenticationToken authentication; - private final Response response; + public static class AssertionToken { + private final Saml2AuthenticationToken token; + private final Assertion assertion; - private Tuple(Saml2AuthenticationToken authentication, Response response) { - this.authentication = authentication; - this.response = response; + private AssertionToken(Assertion assertion, Saml2AuthenticationToken token) { + this.token = token; + this.assertion = assertion; } - public Saml2AuthenticationToken getAuthentication() { - return this.authentication; + public Assertion getAssertion() { + return this.assertion; } - public Response getResponse() { - return this.response; + public Saml2AuthenticationToken getToken() { + return this.token; } } } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/Saml2ResponseValidatorResultTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/Saml2ResponseValidatorResultTests.java new file mode 100644 index 0000000000..fa96940f13 --- /dev/null +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/Saml2ResponseValidatorResultTests.java @@ -0,0 +1,89 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.saml2.core; + +import org.junit.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for verifying {@link Saml2ResponseValidatorResult} + * + * @author Josh Cummings + */ +public class Saml2ResponseValidatorResultTests { + private static final Saml2Error DETAIL = new Saml2Error( + "error", "description"); + + @Test + public void successWhenInvokedThenReturnsSuccessfulResult() { + Saml2ResponseValidatorResult success = Saml2ResponseValidatorResult.success(); + assertThat(success.hasErrors()).isFalse(); + } + + @Test + public void failureWhenInvokedWithDetailReturnsFailureResultIncludingDetail() { + Saml2ResponseValidatorResult failure = Saml2ResponseValidatorResult.failure(DETAIL); + + assertThat(failure.hasErrors()).isTrue(); + assertThat(failure.getErrors()).containsExactly(DETAIL); + } + + @Test + public void failureWhenInvokedWithMultipleDetailsReturnsFailureResultIncludingAll() { + Saml2ResponseValidatorResult failure = Saml2ResponseValidatorResult.failure(DETAIL, DETAIL); + + assertThat(failure.hasErrors()).isTrue(); + assertThat(failure.getErrors()).containsExactly(DETAIL, DETAIL); + } + + @Test + public void concatErrorWhenInvokedThenReturnsCopyContainingAll() { + Saml2ResponseValidatorResult failure = Saml2ResponseValidatorResult.failure(DETAIL); + Saml2ResponseValidatorResult added = failure.concat(DETAIL); + + assertThat(added.hasErrors()).isTrue(); + assertThat(added.getErrors()).containsExactly(DETAIL, DETAIL); + assertThat(failure).isNotSameAs(added); + } + + @Test + public void concatResultWhenInvokedThenReturnsCopyContainingAll() { + Saml2ResponseValidatorResult failure = Saml2ResponseValidatorResult.failure(DETAIL); + Saml2ResponseValidatorResult merged = failure + .concat(failure) + .concat(failure); + + assertThat(merged.hasErrors()).isTrue(); + assertThat(merged.getErrors()).containsExactly(DETAIL, DETAIL, DETAIL); + assertThat(failure).isNotSameAs(merged); + } + + @Test + public void concatErrorWhenNullThenIllegalArgument() { + assertThatThrownBy(() -> Saml2ResponseValidatorResult.failure(DETAIL) + .concat((Saml2Error) null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void concatResultWhenNullThenIllegalArgument() { + assertThatThrownBy(() -> Saml2ResponseValidatorResult.failure(DETAIL) + .concat((Saml2ResponseValidatorResult) null)) + .isInstanceOf(IllegalArgumentException.class); + } +} diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java index 2b7e885c1e..d438f39e63 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java @@ -45,12 +45,9 @@ import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; import org.opensaml.core.xml.io.Marshaller; import org.opensaml.core.xml.io.MarshallingException; import org.opensaml.saml.common.assertion.ValidationContext; -import org.opensaml.saml.common.assertion.ValidationResult; -import org.opensaml.saml.saml2.assertion.impl.OneTimeUseConditionValidator; import org.opensaml.saml.saml2.core.Assertion; import org.opensaml.saml.saml2.core.AttributeStatement; import org.opensaml.saml.saml2.core.AttributeValue; -import org.opensaml.saml.saml2.core.Condition; import org.opensaml.saml.saml2.core.EncryptedAssertion; import org.opensaml.saml.saml2.core.EncryptedID; import org.opensaml.saml.saml2.core.NameID; @@ -60,13 +57,17 @@ import org.w3c.dom.Document; import org.w3c.dom.Element; import org.xml.sax.InputSource; +import org.springframework.core.convert.converter.Converter; import org.springframework.security.core.Authentication; import org.springframework.security.saml2.Saml2Exception; +import org.springframework.security.saml2.core.Saml2Error; +import org.springframework.security.saml2.core.Saml2ResponseValidatorResult; import org.springframework.security.saml2.credentials.Saml2X509Credential; import static java.util.Collections.singleton; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; @@ -76,11 +77,14 @@ import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getB import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getMarshallerFactory; import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS; import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SIGNATURE_REQUIRED; +import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_ASSERTION; +import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_SIGNATURE; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyEncryptingCredential; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyPrivateCredential; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartySigningCredential; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyDecryptingCredential; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyVerifyingCredential; +import static org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider.createDefaultAssertionValidator; import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.assertion; import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.attributeStatements; import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.encrypted; @@ -365,10 +369,13 @@ public class OpenSamlAuthenticationProviderTests { } @Test - public void authenticateWhenConditionValidatorsCustomizedThenUses() throws Exception { - OneTimeUseConditionValidator validator = mock(OneTimeUseConditionValidator.class); + public void authenticateWhenDelegatingToDefaultAssertionValidatorThenUses() { OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider(); - provider.setConditionValidators(Collections.singleton(validator)); + provider.setAssertionValidator(assertionToken -> { + ValidationContext context = new ValidationContext(); + return createDefaultAssertionValidator(context).convert(assertionToken) + .concat(new Saml2Error("wrong error", "wrong error")); + }); Response response = response(); Assertion assertion = assertion(); OneTimeUse oneTimeUse = build(OneTimeUse.DEFAULT_ELEMENT_NAME); @@ -376,11 +383,46 @@ public class OpenSamlAuthenticationProviderTests { response.getAssertions().add(assertion); signed(response, assertingPartySigningCredential(), ASSERTING_PARTY_ENTITY_ID); Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential()); - when(validator.getServicedCondition()).thenReturn(OneTimeUse.DEFAULT_ELEMENT_NAME); - when(validator.validate(any(Condition.class), any(Assertion.class), any(ValidationContext.class))) - .thenReturn(ValidationResult.VALID); + assertThatThrownBy(() -> provider.authenticate(token)) + .isInstanceOf(Saml2AuthenticationException.class) + .hasFieldOrPropertyWithValue("error.errorCode", INVALID_ASSERTION); + } + + @Test + public void authenticateWhenCustomAssertionValidatorThenUses() { + Converter validator = + mock(Converter.class); + OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider(); + provider.setAssertionValidator(assertionToken -> { + ValidationContext context = new ValidationContext( + Collections.singletonMap(SC_VALID_RECIPIENTS, singleton(DESTINATION))); + return createDefaultAssertionValidator(context).convert(assertionToken) + .concat(validator.convert(assertionToken)); + }); + Response response = response(); + Assertion assertion = assertion(); + response.getAssertions().add(assertion); + signed(response, assertingPartySigningCredential(), ASSERTING_PARTY_ENTITY_ID); + Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential()); + when(validator.convert(any(OpenSamlAuthenticationProvider.AssertionToken.class))) + .thenReturn(Saml2ResponseValidatorResult.success()); provider.authenticate(token); - verify(validator).validate(any(Condition.class), any(Assertion.class), any(ValidationContext.class)); + verify(validator).convert(any(OpenSamlAuthenticationProvider.AssertionToken.class)); + } + + @Test + public void authenticateWhenDefaultConditionValidatorNotUsedThenSignatureStillChecked() { + OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider(); + provider.setAssertionValidator(assertionToken -> Saml2ResponseValidatorResult.success()); + Response response = response(); + Assertion assertion = assertion(); + signed(assertion, relyingPartyDecryptingCredential(), RELYING_PARTY_ENTITY_ID); // broken signature + response.getAssertions().add(assertion); + signed(response, assertingPartySigningCredential(), ASSERTING_PARTY_ENTITY_ID); + Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential()); + assertThatThrownBy(() -> provider.authenticate(token)) + .isInstanceOf(Saml2AuthenticationException.class) + .hasFieldOrPropertyWithValue("error.errorCode", INVALID_SIGNATURE); } @Test @@ -391,7 +433,7 @@ public class OpenSamlAuthenticationProviderTests { ValidationContext context = mock(ValidationContext.class); when(context.getStaticParameters()).thenReturn(parameters); OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider(); - provider.setValidationContextConverter(tuple -> context); + provider.setAssertionValidator(assertionToken -> createDefaultAssertionValidator(context).convert(assertionToken)); Response response = response(); Assertion assertion = assertion(); response.getAssertions().add(assertion); @@ -402,17 +444,8 @@ public class OpenSamlAuthenticationProviderTests { } @Test - public void setValidationContextConverterWhenNullThenIllegalArgument() { - assertThatCode(() -> this.provider.setValidationContextConverter(null)) - .isInstanceOf(IllegalArgumentException.class); - } - - @Test - public void setConditionValidatorsWhenNullOrEmptyThenIllegalArgument() { - assertThatCode(() -> this.provider.setConditionValidators(null)) - .isInstanceOf(IllegalArgumentException.class); - - assertThatCode(() -> this.provider.setConditionValidators(Collections.emptyList())) + public void setAssertionValidatorWhenNullThenIllegalArgument() { + assertThatCode(() -> this.provider.setAssertionValidator(null)) .isInstanceOf(IllegalArgumentException.class); } diff --git a/samples/boot/saml2login/src/integration-test/java/org/springframework/security/saml2/provider/service/authentication/Saml2LoginIntegrationTests.java b/samples/boot/saml2login/src/integration-test/java/org/springframework/security/saml2/provider/service/authentication/Saml2LoginIntegrationTests.java index d51d2cdc4b..176826fea4 100644 --- a/samples/boot/saml2login/src/integration-test/java/org/springframework/security/saml2/provider/service/authentication/Saml2LoginIntegrationTests.java +++ b/samples/boot/saml2login/src/integration-test/java/org/springframework/security/saml2/provider/service/authentication/Saml2LoginIntegrationTests.java @@ -242,7 +242,7 @@ public class Saml2LoginIntegrationTests { sendResponse(response, "/login?error") .andExpect( saml2AuthenticationExceptionMatcher( - "invalid_assertion", + "invalid_signature", containsString("Invalid assertion [assertion] for SAML response") ) ); @@ -288,9 +288,9 @@ public class Saml2LoginIntegrationTests { .andExpect(unauthenticated()) .andExpect( saml2AuthenticationExceptionMatcher( - "invalid_issuer", + "invalid_signature", containsString( - "Invalid issuer" + "Invalid signature" ) ) );