diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java index c5b5363b9b..e2c8bdb35c 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java @@ -21,13 +21,13 @@ import java.time.Duration; import java.time.Instant; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Set; -import java.util.function.Consumer; import java.util.function.Function; import javax.annotation.Nonnull; @@ -193,10 +193,12 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi private Converter signatureTrustEngineConverter = new SignatureTrustEngineConverter(); - private Converter assertionValidatorConverter = + private Converter assertionValidatorConverter = new SAML20AssertionValidatorConverter(); - private Converter validationContextConverter = - new ValidationContextConverter(params -> {}); + private Collection conditionValidators = + Collections.singleton(new AudienceRestrictionConditionValidator()); + private Converter validationContextConverter = + new ValidationContextConverter(); private Converter decrypterConverter = new DecrypterConverter(); /** @@ -209,6 +211,33 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi this.parserPool = this.registry.getParserPool(); } + /** + * Set the the collection of {@link ConditionValidator}s used when validating an assertion. + * + * @param conditionValidators the collection of validators to use + * @since 5.4 + */ + public void setConditionValidators( + Collection conditionValidators) { + + Assert.notEmpty(conditionValidators, "conditionValidators cannot be empty"); + this.conditionValidators = conditionValidators; + } + + /** + * Set the strategy for retrieving the {@link ValidationContext} used when + * validating an assertion. + * + * @param validationContextConverter the strategy to use + * @since 5.4 + */ + public void setValidationContextConverter( + Converter validationContextConverter) { + + Assert.notNull(validationContextConverter, "validationContextConverter cannot be empty"); + this.validationContextConverter = validationContextConverter; + } + /** * Sets the {@link Converter} used for extracting assertion attributes that * can be mapped to authorities. @@ -238,8 +267,6 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi */ public void setResponseTimeValidationSkew(Duration responseTimeValidationSkew) { this.responseTimeValidationSkew = responseTimeValidationSkew; - this.validationContextConverter = new ValidationContextConverter( - params -> params.put(CLOCK_SKEW, responseTimeValidationSkew.toMillis())); } /** @@ -303,7 +330,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi throw authException(INVALID_SIGNATURE, "Either the response or one of the assertions is unsigned. " + "Please either sign the response or all of the assertions."); } - validationExceptions.putAll(validateAssertions(token, assertions)); + validationExceptions.putAll(validateAssertions(token, response)); Assertion firstAssertion = CollectionUtils.firstElement(response.getAssertions()); NameID nameId = decryptPrincipal(decrypter, firstAssertion); @@ -392,7 +419,8 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi } private Map validateAssertions - (Saml2AuthenticationToken token, List assertions) { + (Saml2AuthenticationToken token, Response response) { + List assertions = response.getAssertions(); if (assertions.isEmpty()) { throw authException(MALFORMED_RESPONSE_DATA, "No assertions found in response."); } @@ -401,14 +429,16 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi if (logger.isDebugEnabled()) { logger.debug("Validating " + assertions.size() + " assertions"); } + + Tuple tuple = new Tuple(token, response); + SAML20AssertionValidator validator = this.assertionValidatorConverter.convert(tuple); + ValidationContext context = this.validationContextConverter.convert(tuple); for (Assertion assertion : assertions) { if (logger.isTraceEnabled()) { logger.trace("Validating assertion " + assertion.getID()); } try { - ValidationContext context = this.validationContextConverter.convert(token); - ValidationResult result = this.assertionValidatorConverter.convert(token).validate(assertion, context); - if (result != ValidationResult.VALID) { + if (validator.validate(assertion, context) != ValidationResult.VALID) { String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s", assertion.getID(), ((Response) assertion.getParent()).getID(), context.getValidationFailureMessage()); @@ -512,6 +542,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi } private static class SignatureTrustEngineConverter implements Converter { + @Override public SignatureTrustEngine convert(Saml2AuthenticationToken token) { Set credentials = new HashSet<>(); @@ -530,35 +561,27 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi } } - private static class ValidationContextConverter implements Converter { - Consumer> validationContextParametersConverter; - - ValidationContextConverter(Consumer> validationContextParametersConverter) { - this.validationContextParametersConverter = validationContextParametersConverter; - } + private class ValidationContextConverter implements Converter { @Override - public ValidationContext convert(Saml2AuthenticationToken token) { - String audience = token.getRelyingPartyRegistration().getEntityId(); - String recipient = token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation(); + public ValidationContext convert(Tuple tuple) { + String audience = tuple.authentication.getRelyingPartyRegistration().getEntityId(); + String recipient = tuple.authentication.getRelyingPartyRegistration().getAssertionConsumerServiceLocation(); Map params = new HashMap<>(); - params.put(CLOCK_SKEW, Duration.ofMinutes(5).toMillis()); + params.put(CLOCK_SKEW, OpenSamlAuthenticationProvider.this.responseTimeValidationSkew.toMillis()); params.put(COND_VALID_AUDIENCES, singleton(audience)); params.put(SC_VALID_RECIPIENTS, singleton(recipient)); params.put(SIGNATURE_REQUIRED, false); // this verification is performed earlier - this.validationContextParametersConverter.accept(params); return new ValidationContext(params); } } - private class SAML20AssertionValidatorConverter implements Converter { - private final Collection conditions = new ArrayList<>(); + private class SAML20AssertionValidatorConverter implements Converter { private final Collection subjects = new ArrayList<>(); private final Collection statements = new ArrayList<>(); private final SignaturePrevalidator validator = new SAMLSignatureProfileValidator(); SAML20AssertionValidatorConverter() { - this.conditions.add(new AudienceRestrictionConditionValidator()); this.subjects.add(new BearerSubjectConfirmationValidator() { @Nonnull @Override @@ -571,9 +594,11 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi } @Override - public SAML20AssertionValidator convert(Saml2AuthenticationToken token) { - return new SAML20AssertionValidator(this.conditions, this.subjects, this.statements, - OpenSamlAuthenticationProvider.this.signatureTrustEngineConverter.convert(token), + public SAML20AssertionValidator convert(Tuple tuple) { + Collection conditions = + OpenSamlAuthenticationProvider.this.conditionValidators; + return new SAML20AssertionValidator(conditions, this.subjects, this.statements, + OpenSamlAuthenticationProvider.this.signatureTrustEngineConverter.convert(tuple.authentication), this.validator); } } @@ -616,4 +641,27 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi return new Saml2AuthenticationException(validationError(code, description), cause); } + + /** + * A tuple containing the authentication token and the associated OpenSAML {@link Response}. + * + * @since 5.4 + */ + public static class Tuple { + private final Saml2AuthenticationToken authentication; + private final Response response; + + private Tuple(Saml2AuthenticationToken authentication, Response response) { + this.authentication = authentication; + this.response = response; + } + + public Saml2AuthenticationToken getAuthentication() { + return this.authentication; + } + + public Response getResponse() { + return this.response; + } + } } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java index 4851b6f15b..2b7e885c1e 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java @@ -23,9 +23,11 @@ import java.io.StringReader; import java.time.Instant; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import javax.xml.namespace.QName; import javax.xml.parsers.DocumentBuilder; import javax.xml.parsers.DocumentBuilderFactory; @@ -42,12 +44,17 @@ import org.opensaml.core.xml.XMLObject; import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; import org.opensaml.core.xml.io.Marshaller; import org.opensaml.core.xml.io.MarshallingException; +import org.opensaml.saml.common.assertion.ValidationContext; +import org.opensaml.saml.common.assertion.ValidationResult; +import org.opensaml.saml.saml2.assertion.impl.OneTimeUseConditionValidator; import org.opensaml.saml.saml2.core.Assertion; import org.opensaml.saml.saml2.core.AttributeStatement; import org.opensaml.saml.saml2.core.AttributeValue; +import org.opensaml.saml.saml2.core.Condition; import org.opensaml.saml.saml2.core.EncryptedAssertion; import org.opensaml.saml.saml2.core.EncryptedID; import org.opensaml.saml.saml2.core.NameID; +import org.opensaml.saml.saml2.core.OneTimeUse; import org.opensaml.saml.saml2.core.Response; import org.w3c.dom.Document; import org.w3c.dom.Element; @@ -57,7 +64,9 @@ import org.springframework.security.core.Authentication; import org.springframework.security.saml2.Saml2Exception; import org.springframework.security.saml2.credentials.Saml2X509Credential; +import static java.util.Collections.singleton; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; @@ -65,6 +74,8 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getBuilderFactory; import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getMarshallerFactory; +import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS; +import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SIGNATURE_REQUIRED; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyEncryptingCredential; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyPrivateCredential; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartySigningCredential; @@ -353,6 +364,62 @@ public class OpenSamlAuthenticationProviderTests { objectOutputStream.flush(); } + @Test + public void authenticateWhenConditionValidatorsCustomizedThenUses() throws Exception { + OneTimeUseConditionValidator validator = mock(OneTimeUseConditionValidator.class); + OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider(); + provider.setConditionValidators(Collections.singleton(validator)); + Response response = response(); + Assertion assertion = assertion(); + OneTimeUse oneTimeUse = build(OneTimeUse.DEFAULT_ELEMENT_NAME); + assertion.getConditions().getConditions().add(oneTimeUse); + response.getAssertions().add(assertion); + signed(response, assertingPartySigningCredential(), ASSERTING_PARTY_ENTITY_ID); + Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential()); + when(validator.getServicedCondition()).thenReturn(OneTimeUse.DEFAULT_ELEMENT_NAME); + when(validator.validate(any(Condition.class), any(Assertion.class), any(ValidationContext.class))) + .thenReturn(ValidationResult.VALID); + provider.authenticate(token); + verify(validator).validate(any(Condition.class), any(Assertion.class), any(ValidationContext.class)); + } + + @Test + public void authenticateWhenValidationContextCustomizedThenUsers() { + Map parameters = new HashMap<>(); + parameters.put(SC_VALID_RECIPIENTS, singleton(DESTINATION)); + parameters.put(SIGNATURE_REQUIRED, false); + ValidationContext context = mock(ValidationContext.class); + when(context.getStaticParameters()).thenReturn(parameters); + OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider(); + provider.setValidationContextConverter(tuple -> context); + Response response = response(); + Assertion assertion = assertion(); + response.getAssertions().add(assertion); + signed(response, assertingPartySigningCredential(), ASSERTING_PARTY_ENTITY_ID); + Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential()); + provider.authenticate(token); + verify(context, atLeastOnce()).getStaticParameters(); + } + + @Test + public void setValidationContextConverterWhenNullThenIllegalArgument() { + assertThatCode(() -> this.provider.setValidationContextConverter(null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void setConditionValidatorsWhenNullOrEmptyThenIllegalArgument() { + assertThatCode(() -> this.provider.setConditionValidators(null)) + .isInstanceOf(IllegalArgumentException.class); + + assertThatCode(() -> this.provider.setConditionValidators(Collections.emptyList())) + .isInstanceOf(IllegalArgumentException.class); + } + + private T build(QName qName) { + return (T) getBuilderFactory().getBuilder(qName).buildObject(qName); + } + private String serialize(XMLObject object) { try { Marshaller marshaller = getMarshallerFactory().getMarshaller(object);