From 3694485056a19a40bfdef92852ed18acac9fb144 Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Tue, 18 Aug 2020 17:08:37 -0600 Subject: [PATCH] Polish SAML 2.0 Default Assertion Validator In several cases, taking a pre-set ValidationContext is not sufficient. For example, the recipient is calculated via the RelyingPartyRegistration that's currently in the context of the request. Instead, then, createDefaultAssertionValidator was broken up into two different methods: One that takes no parameters and assumes the class's default ValidationContext, and another that takes a converter to derive the ValidationContext from the incoming authentication token. Issue gh-8970 --- .../OpenSamlAuthenticationProvider.java | 115 ++++++++++++------ .../OpenSamlAuthenticationProviderTests.java | 47 +++---- .../authentication/TestOpenSamlObjects.java | 11 ++ 3 files changed, 117 insertions(+), 56 deletions(-) diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java index f23c644fbb..11fa522588 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java @@ -28,6 +28,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.Consumer; import javax.annotation.Nonnull; import javax.xml.namespace.QName; @@ -196,10 +197,23 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi this.authoritiesMapper.mapAuthorities(getAssertionAuthorities(assertion))); }; - private Converter assertionValidator = assertionToken -> { - ValidationContext context = createValidationContext(assertionToken); - return createDefaultAssertionValidator(context).convert(assertionToken); - }; + private Converter assertionSignatureValidator = + createDefaultAssertionValidator(INVALID_SIGNATURE, + assertionToken -> { + SignatureTrustEngine engine = this.signatureTrustEngineConverter.convert(assertionToken.token); + return SAML20AssertionValidators.createSignatureValidator(engine); + }, + assertionToken -> + new ValidationContext(Collections.singletonMap(SIGNATURE_REQUIRED, false)) + ); + + private Converter assertionValidator = + createDefaultAssertionValidator(INVALID_ASSERTION, + assertionToken -> SAML20AssertionValidators.attributeValidator, + assertionToken -> createValidationContext( + assertionToken, + params -> params.put(CLOCK_SKEW, this.responseTimeValidationSkew.toMillis()) + )); private Converter signatureTrustEngineConverter = new SignatureTrustEngineConverter(); @@ -220,34 +234,40 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi * Set the {@link Converter} to use for validating each {@link Assertion} in the SAML 2.0 Response. * * You can still invoke the default validator by delgating to - * {@link #createDefaultAssertionValidator(ValidationContext)}, like so: + * {@link #createDefaultAssertionValidator}, like so: * *
 	 *	OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
 	 *  provider.setAssertionValidator(assertionToken -> {
-	 *		ValidationContext context = // ... build using authentication token
-	 *		Saml2ResponseValidatorResult result = createDefaultAssertionValidator(context)
+	 *		Saml2ResponseValidatorResult result = createDefaultAssertionValidator()
 	 *			.convert(assertionToken)
-	 *		return result.concat(myCustomValiator.convert(assertionToken));
+	 *		return result.concat(myCustomValidator.convert(assertionToken));
 	 *  });
 	 * 
* - * Consider taking a look at {@link #createValidationContext(AssertionToken)} to see how it - * constructs a {@link ValidationContext}. - * * You can also use this method to configure the provider to use a different * {@link ValidationContext} from the default, like so: * *
 	 *	OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
-	 *	ValidationContext context = // ...
-	 *	provider.setAssertionValidator(createDefaultAssertionValidator(context));
+	 *	provider.setAssertionValidator(
+	 *		createDefaultAssertionValidator(assertionToken -> {
+	 *			Map<String, Object> params = new HashMap<>();
+	 *			params.put(CLOCK_SKEW, 2 * 60 * 1000);
+	 *			// other parameters
+	 *			return new ValidationContext(params);
+	 *		}));
 	 * 
* + * Consider taking a look at {@link #createValidationContext} to see how it + * constructs a {@link ValidationContext}. + * * It is not necessary to delegate to the default validator. You can safely replace it * entirely with your own. Note that signature verification is performed as a separate * step from this validator. * + * This method takes precedence over {@link #setResponseTimeValidationSkew}. + * * @param assertionValidator * @since 5.4 */ @@ -314,11 +334,45 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi * Sets the duration for how much time skew an assertion may tolerate during * timestamp, NotOnOrBefore and NotOnOrAfter, validation. * @param responseTimeValidationSkew duration for skew tolerance + * @deprecated Use {@link #setAssertionValidator(Converter)} instead */ public void setResponseTimeValidationSkew(Duration responseTimeValidationSkew) { this.responseTimeValidationSkew = responseTimeValidationSkew; } + + /** + * Construct a default strategy for validating each SAML 2.0 Assertion and + * associated {@link Authentication} token + * + * @return the default assertion validator strategy + * @since 5.4 + */ + public static Converter + createDefaultAssertionValidator() { + + return createDefaultAssertionValidator(INVALID_ASSERTION, + assertionToken -> SAML20AssertionValidators.attributeValidator, + assertionToken -> createValidationContext(assertionToken, params -> {})); + } + + /** + * Construct a default strategy for validating each SAML 2.0 Assertion and + * associated {@link Authentication} token + * + * @return the default assertion validator strategy + * @param contextConverter the conversion strategy to use to generate a {@link ValidationContext} + * for each assertion being validated + * @since 5.4 + */ + public static Converter + createDefaultAssertionValidator(Converter contextConverter) { + + return createDefaultAssertionValidator(INVALID_ASSERTION, + assertionToken -> SAML20AssertionValidators.attributeValidator, + contextConverter); + } + /** * Construct a default strategy for converting a SAML 2.0 Response and {@link Authentication} * token into a {@link Saml2Authentication} @@ -501,19 +555,13 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi logger.debug("Validating " + assertions.size() + " assertions"); } - ValidationContext signatureContext = new ValidationContext - (Collections.singletonMap(SIGNATURE_REQUIRED, false)); // check already performed - SignatureTrustEngine engine = this.signatureTrustEngineConverter.convert(token); - Converter signatureValidator = - createDefaultAssertionValidator(INVALID_SIGNATURE, - SAML20AssertionValidators.createSignatureValidator(engine), signatureContext); for (Assertion assertion : assertions) { if (logger.isTraceEnabled()) { logger.trace("Validating assertion " + assertion.getID()); } AssertionToken assertionToken = new AssertionToken(assertion, token); result = result - .concat(signatureValidator.convert(assertionToken)) + .concat(this.assertionSignatureValidator.convert(assertionToken)) .concat(this.assertionValidator.convert(assertionToken)); } @@ -613,18 +661,15 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi } } - public static Converter - createDefaultAssertionValidator(ValidationContext context) { - - return createDefaultAssertionValidator(INVALID_ASSERTION, - SAML20AssertionValidators.createAttributeValidator(), context); - } - - private static Converter - createDefaultAssertionValidator(String errorCode, SAML20AssertionValidator validator, ValidationContext context) { + private static Converter createDefaultAssertionValidator( + String errorCode, + Converter validatorConverter, + Converter contextConverter) { return assertionToken -> { Assertion assertion = assertionToken.assertion; + SAML20AssertionValidator validator = validatorConverter.convert(assertionToken); + ValidationContext context = contextConverter.convert(assertionToken); try { ValidationResult result = validator.validate(assertion, context); if (result == ValidationResult.VALID) { @@ -643,13 +688,14 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi }; } - private ValidationContext createValidationContext(AssertionToken assertionToken) { + private static ValidationContext createValidationContext( + AssertionToken assertionToken, Consumer> paramsConsumer) { String audience = assertionToken.token.getRelyingPartyRegistration().getEntityId(); String recipient = assertionToken.token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation(); Map params = new HashMap<>(); - params.put(CLOCK_SKEW, OpenSamlAuthenticationProvider.this.responseTimeValidationSkew.toMillis()); params.put(COND_VALID_AUDIENCES, singleton(audience)); params.put(SC_VALID_RECIPIENTS, singleton(recipient)); + paramsConsumer.accept(params); return new ValidationContext(params); } @@ -687,15 +733,14 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi }); } - static SAML20AssertionValidator createAttributeValidator() { - return new SAML20AssertionValidator(conditions, subjects, statements, null, null) { + private static final SAML20AssertionValidator attributeValidator = + new SAML20AssertionValidator(conditions, subjects, statements, null, null) { @Nonnull @Override protected ValidationResult validateSignature(Assertion token, ValidationContext context) { return ValidationResult.VALID; } }; - } static SAML20AssertionValidator createSignatureValidator(SignatureTrustEngine engine) { return new SAML20AssertionValidator(new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), @@ -792,7 +837,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi private final Saml2AuthenticationToken token; private final Assertion assertion; - private AssertionToken(Assertion assertion, Saml2AuthenticationToken token) { + AssertionToken(Assertion assertion, Saml2AuthenticationToken token) { this.token = token; this.assertion = assertion; } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java index ab60156562..ecf2f785f7 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java @@ -19,7 +19,6 @@ package org.springframework.security.saml2.provider.service.authentication; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.ObjectOutputStream; -import java.io.StringReader; import java.time.Instant; import java.util.Arrays; import java.util.Collections; @@ -28,8 +27,6 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import javax.xml.namespace.QName; -import javax.xml.parsers.DocumentBuilder; -import javax.xml.parsers.DocumentBuilderFactory; import net.shibboleth.utilities.java.support.xml.SerializeSupport; import org.hamcrest.BaseMatcher; @@ -51,9 +48,7 @@ import org.opensaml.saml.saml2.core.EncryptedID; import org.opensaml.saml.saml2.core.NameID; import org.opensaml.saml.saml2.core.OneTimeUse; import org.opensaml.saml.saml2.core.Response; -import org.w3c.dom.Document; import org.w3c.dom.Element; -import org.xml.sax.InputSource; import org.springframework.core.convert.converter.Converter; import org.springframework.security.core.Authentication; @@ -74,7 +69,6 @@ import static org.mockito.Mockito.when; import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getBuilderFactory; import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getMarshallerFactory; import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS; -import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SIGNATURE_REQUIRED; import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_ASSERTION; import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_SIGNATURE; import static org.springframework.security.saml2.core.Saml2ResponseValidatorResult.success; @@ -350,14 +344,23 @@ public class OpenSamlAuthenticationProviderTests { objectOutputStream.flush(); } + @Test + public void createDefaultAssertionValidatorWhenAssertionThenValidates() { + Response response = signedResponseWithOneAssertion(); + Assertion assertion = response.getAssertions().get(0); + OpenSamlAuthenticationProvider.AssertionToken assertionToken = + new OpenSamlAuthenticationProvider.AssertionToken(assertion, token()); + assertThat( + createDefaultAssertionValidator().convert(assertionToken) + .hasErrors()).isFalse(); + } + @Test public void authenticateWhenDelegatingToDefaultAssertionValidatorThenUses() { OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider(); - provider.setAssertionValidator(assertionToken -> { - ValidationContext context = new ValidationContext(); - return createDefaultAssertionValidator(context).convert(assertionToken) - .concat(new Saml2Error("wrong error", "wrong error")); - }); + provider.setAssertionValidator(assertionToken -> + createDefaultAssertionValidator(token -> new ValidationContext()).convert(assertionToken) + .concat(new Saml2Error("wrong error", "wrong error"))); Response response = response(); Assertion assertion = assertion(); OneTimeUse oneTimeUse = build(OneTimeUse.DEFAULT_ELEMENT_NAME); @@ -375,12 +378,9 @@ public class OpenSamlAuthenticationProviderTests { Converter validator = mock(Converter.class); OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider(); - provider.setAssertionValidator(assertionToken -> { - ValidationContext context = new ValidationContext( - Collections.singletonMap(SC_VALID_RECIPIENTS, singleton(DESTINATION))); - return createDefaultAssertionValidator(context).convert(assertionToken) - .concat(validator.convert(assertionToken)); - }); + provider.setAssertionValidator(assertionToken -> + createDefaultAssertionValidator().convert(assertionToken) + .concat(validator.convert(assertionToken))); Response response = response(); Assertion assertion = assertion(); response.getAssertions().add(assertion); @@ -410,18 +410,19 @@ public class OpenSamlAuthenticationProviderTests { @Test public void authenticateWhenValidationContextCustomizedThenUsers() { Map parameters = new HashMap<>(); - parameters.put(SC_VALID_RECIPIENTS, singleton(DESTINATION)); - parameters.put(SIGNATURE_REQUIRED, false); + parameters.put(SC_VALID_RECIPIENTS, singleton("blah")); ValidationContext context = mock(ValidationContext.class); when(context.getStaticParameters()).thenReturn(parameters); OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider(); - provider.setAssertionValidator(assertionToken -> createDefaultAssertionValidator(context).convert(assertionToken)); + provider.setAssertionValidator(createDefaultAssertionValidator(assertionToken -> context)); Response response = response(); Assertion assertion = assertion(); response.getAssertions().add(assertion); signed(response, assertingPartySigningCredential(), ASSERTING_PARTY_ENTITY_ID); Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential()); - provider.authenticate(token); + assertThatThrownBy(() -> provider.authenticate(token)) + .isInstanceOf(Saml2AuthenticationException.class) + .hasMessageContaining("Invalid assertion"); verify(context, atLeastOnce()).getStaticParameters(); } @@ -506,6 +507,10 @@ public class OpenSamlAuthenticationProviderTests { }; } + private Saml2AuthenticationToken token() { + return token(response(), relyingPartyVerifyingCredential()); + } + private Saml2AuthenticationToken token(Response response, Saml2X509Credential... credentials) { String payload = serialize(response); return token(payload, credentials); diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java index 1df38bd978..3032ae3af6 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java @@ -19,7 +19,10 @@ package org.springframework.security.saml2.provider.service.authentication; import java.security.cert.X509Certificate; import java.util.ArrayList; import java.util.Base64; +import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.UUID; import javax.crypto.SecretKey; import javax.crypto.spec.SecretKeySpec; @@ -45,6 +48,7 @@ import org.opensaml.core.xml.schema.impl.XSStringBuilder; import org.opensaml.core.xml.schema.impl.XSURIBuilder; import org.opensaml.saml.common.SAMLVersion; import org.opensaml.saml.common.SignableSAMLObject; +import org.opensaml.saml.common.assertion.ValidationContext; import org.opensaml.saml.saml2.core.Assertion; import org.opensaml.saml.saml2.core.Attribute; import org.opensaml.saml.saml2.core.AttributeStatement; @@ -79,6 +83,7 @@ import org.springframework.security.saml2.core.OpenSamlInitializationService; import org.springframework.security.saml2.core.Saml2X509Credential; import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getBuilderFactory; +import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS; import static org.springframework.security.saml2.core.TestSaml2X509Credentials.assertingPartySigningCredential; final class TestOpenSamlObjects { @@ -371,6 +376,12 @@ final class TestOpenSamlObjects { return attributeStatements; } + static ValidationContext validationContext() { + Map params = new HashMap<>(); + params.put(SC_VALID_RECIPIENTS, Collections.singleton(DESTINATION)); + return new ValidationContext(params); + } + static T build(QName qName) { return (T) getBuilderFactory().getBuilder(qName).buildObject(qName); }