Polish SAML 2.0 Default Assertion Validator

In several cases, taking a pre-set ValidationContext is not sufficient.
For example, the recipient is calculated via the
RelyingPartyRegistration that's currently in the context of the
request.

Instead, then, createDefaultAssertionValidator was broken up into two
different methods: One that takes no parameters and assumes the class's
default ValidationContext, and another that takes a converter to derive
the ValidationContext from the incoming authentication token.

Issue gh-8970
This commit is contained in:
Josh Cummings 2020-08-18 17:08:37 -06:00
parent da7477cd41
commit 3694485056
No known key found for this signature in database
GPG Key ID: 49EF60DD7FF83443
3 changed files with 117 additions and 56 deletions

View File

@ -28,6 +28,7 @@ import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Consumer;
import javax.annotation.Nonnull;
import javax.xml.namespace.QName;
@ -196,10 +197,23 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
this.authoritiesMapper.mapAuthorities(getAssertionAuthorities(assertion)));
};
private Converter<AssertionToken, Saml2ResponseValidatorResult> assertionValidator = assertionToken -> {
ValidationContext context = createValidationContext(assertionToken);
return createDefaultAssertionValidator(context).convert(assertionToken);
};
private Converter<AssertionToken, Saml2ResponseValidatorResult> assertionSignatureValidator =
createDefaultAssertionValidator(INVALID_SIGNATURE,
assertionToken -> {
SignatureTrustEngine engine = this.signatureTrustEngineConverter.convert(assertionToken.token);
return SAML20AssertionValidators.createSignatureValidator(engine);
},
assertionToken ->
new ValidationContext(Collections.singletonMap(SIGNATURE_REQUIRED, false))
);
private Converter<AssertionToken, Saml2ResponseValidatorResult> assertionValidator =
createDefaultAssertionValidator(INVALID_ASSERTION,
assertionToken -> SAML20AssertionValidators.attributeValidator,
assertionToken -> createValidationContext(
assertionToken,
params -> params.put(CLOCK_SKEW, this.responseTimeValidationSkew.toMillis())
));
private Converter<Saml2AuthenticationToken, SignatureTrustEngine> signatureTrustEngineConverter =
new SignatureTrustEngineConverter();
@ -220,34 +234,40 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
* Set the {@link Converter} to use for validating each {@link Assertion} in the SAML 2.0 Response.
*
* You can still invoke the default validator by delgating to
* {@link #createDefaultAssertionValidator(ValidationContext)}, like so:
* {@link #createDefaultAssertionValidator}, like so:
*
* <pre>
* OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
* provider.setAssertionValidator(assertionToken -> {
* ValidationContext context = // ... build using authentication token
* Saml2ResponseValidatorResult result = createDefaultAssertionValidator(context)
* Saml2ResponseValidatorResult result = createDefaultAssertionValidator()
* .convert(assertionToken)
* return result.concat(myCustomValiator.convert(assertionToken));
* return result.concat(myCustomValidator.convert(assertionToken));
* });
* </pre>
*
* Consider taking a look at {@link #createValidationContext(AssertionToken)} to see how it
* constructs a {@link ValidationContext}.
*
* You can also use this method to configure the provider to use a different
* {@link ValidationContext} from the default, like so:
*
* <pre>
* OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
* ValidationContext context = // ...
* provider.setAssertionValidator(createDefaultAssertionValidator(context));
* provider.setAssertionValidator(
* createDefaultAssertionValidator(assertionToken -> {
* Map&lt;String, Object&gt; params = new HashMap&lt;&gt;();
* params.put(CLOCK_SKEW, 2 * 60 * 1000);
* // other parameters
* return new ValidationContext(params);
* }));
* </pre>
*
* Consider taking a look at {@link #createValidationContext} to see how it
* constructs a {@link ValidationContext}.
*
* It is not necessary to delegate to the default validator. You can safely replace it
* entirely with your own. Note that signature verification is performed as a separate
* step from this validator.
*
* This method takes precedence over {@link #setResponseTimeValidationSkew}.
*
* @param assertionValidator
* @since 5.4
*/
@ -314,11 +334,45 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
* Sets the duration for how much time skew an assertion may tolerate during
* timestamp, NotOnOrBefore and NotOnOrAfter, validation.
* @param responseTimeValidationSkew duration for skew tolerance
* @deprecated Use {@link #setAssertionValidator(Converter)} instead
*/
public void setResponseTimeValidationSkew(Duration responseTimeValidationSkew) {
this.responseTimeValidationSkew = responseTimeValidationSkew;
}
/**
* Construct a default strategy for validating each SAML 2.0 Assertion and
* associated {@link Authentication} token
*
* @return the default assertion validator strategy
* @since 5.4
*/
public static Converter<AssertionToken, Saml2ResponseValidatorResult>
createDefaultAssertionValidator() {
return createDefaultAssertionValidator(INVALID_ASSERTION,
assertionToken -> SAML20AssertionValidators.attributeValidator,
assertionToken -> createValidationContext(assertionToken, params -> {}));
}
/**
* Construct a default strategy for validating each SAML 2.0 Assertion and
* associated {@link Authentication} token
*
* @return the default assertion validator strategy
* @param contextConverter the conversion strategy to use to generate a {@link ValidationContext}
* for each assertion being validated
* @since 5.4
*/
public static Converter<AssertionToken, Saml2ResponseValidatorResult>
createDefaultAssertionValidator(Converter<AssertionToken, ValidationContext> contextConverter) {
return createDefaultAssertionValidator(INVALID_ASSERTION,
assertionToken -> SAML20AssertionValidators.attributeValidator,
contextConverter);
}
/**
* Construct a default strategy for converting a SAML 2.0 Response and {@link Authentication}
* token into a {@link Saml2Authentication}
@ -501,19 +555,13 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
logger.debug("Validating " + assertions.size() + " assertions");
}
ValidationContext signatureContext = new ValidationContext
(Collections.singletonMap(SIGNATURE_REQUIRED, false)); // check already performed
SignatureTrustEngine engine = this.signatureTrustEngineConverter.convert(token);
Converter<AssertionToken, Saml2ResponseValidatorResult> signatureValidator =
createDefaultAssertionValidator(INVALID_SIGNATURE,
SAML20AssertionValidators.createSignatureValidator(engine), signatureContext);
for (Assertion assertion : assertions) {
if (logger.isTraceEnabled()) {
logger.trace("Validating assertion " + assertion.getID());
}
AssertionToken assertionToken = new AssertionToken(assertion, token);
result = result
.concat(signatureValidator.convert(assertionToken))
.concat(this.assertionSignatureValidator.convert(assertionToken))
.concat(this.assertionValidator.convert(assertionToken));
}
@ -613,18 +661,15 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
}
}
public static Converter<AssertionToken, Saml2ResponseValidatorResult>
createDefaultAssertionValidator(ValidationContext context) {
return createDefaultAssertionValidator(INVALID_ASSERTION,
SAML20AssertionValidators.createAttributeValidator(), context);
}
private static Converter<AssertionToken, Saml2ResponseValidatorResult>
createDefaultAssertionValidator(String errorCode, SAML20AssertionValidator validator, ValidationContext context) {
private static Converter<AssertionToken, Saml2ResponseValidatorResult> createDefaultAssertionValidator(
String errorCode,
Converter<AssertionToken, SAML20AssertionValidator> validatorConverter,
Converter<AssertionToken, ValidationContext> contextConverter) {
return assertionToken -> {
Assertion assertion = assertionToken.assertion;
SAML20AssertionValidator validator = validatorConverter.convert(assertionToken);
ValidationContext context = contextConverter.convert(assertionToken);
try {
ValidationResult result = validator.validate(assertion, context);
if (result == ValidationResult.VALID) {
@ -643,13 +688,14 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
};
}
private ValidationContext createValidationContext(AssertionToken assertionToken) {
private static ValidationContext createValidationContext(
AssertionToken assertionToken, Consumer<Map<String, Object>> paramsConsumer) {
String audience = assertionToken.token.getRelyingPartyRegistration().getEntityId();
String recipient = assertionToken.token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation();
Map<String, Object> params = new HashMap<>();
params.put(CLOCK_SKEW, OpenSamlAuthenticationProvider.this.responseTimeValidationSkew.toMillis());
params.put(COND_VALID_AUDIENCES, singleton(audience));
params.put(SC_VALID_RECIPIENTS, singleton(recipient));
paramsConsumer.accept(params);
return new ValidationContext(params);
}
@ -687,15 +733,14 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
});
}
static SAML20AssertionValidator createAttributeValidator() {
return new SAML20AssertionValidator(conditions, subjects, statements, null, null) {
private static final SAML20AssertionValidator attributeValidator =
new SAML20AssertionValidator(conditions, subjects, statements, null, null) {
@Nonnull
@Override
protected ValidationResult validateSignature(Assertion token, ValidationContext context) {
return ValidationResult.VALID;
}
};
}
static SAML20AssertionValidator createSignatureValidator(SignatureTrustEngine engine) {
return new SAML20AssertionValidator(new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),
@ -792,7 +837,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
private final Saml2AuthenticationToken token;
private final Assertion assertion;
private AssertionToken(Assertion assertion, Saml2AuthenticationToken token) {
AssertionToken(Assertion assertion, Saml2AuthenticationToken token) {
this.token = token;
this.assertion = assertion;
}

View File

@ -19,7 +19,6 @@ package org.springframework.security.saml2.provider.service.authentication;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.StringReader;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
@ -28,8 +27,6 @@ import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import javax.xml.namespace.QName;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import net.shibboleth.utilities.java.support.xml.SerializeSupport;
import org.hamcrest.BaseMatcher;
@ -51,9 +48,7 @@ import org.opensaml.saml.saml2.core.EncryptedID;
import org.opensaml.saml.saml2.core.NameID;
import org.opensaml.saml.saml2.core.OneTimeUse;
import org.opensaml.saml.saml2.core.Response;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.xml.sax.InputSource;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.core.Authentication;
@ -74,7 +69,6 @@ import static org.mockito.Mockito.when;
import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getBuilderFactory;
import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getMarshallerFactory;
import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS;
import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SIGNATURE_REQUIRED;
import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_ASSERTION;
import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_SIGNATURE;
import static org.springframework.security.saml2.core.Saml2ResponseValidatorResult.success;
@ -350,14 +344,23 @@ public class OpenSamlAuthenticationProviderTests {
objectOutputStream.flush();
}
@Test
public void createDefaultAssertionValidatorWhenAssertionThenValidates() {
Response response = signedResponseWithOneAssertion();
Assertion assertion = response.getAssertions().get(0);
OpenSamlAuthenticationProvider.AssertionToken assertionToken =
new OpenSamlAuthenticationProvider.AssertionToken(assertion, token());
assertThat(
createDefaultAssertionValidator().convert(assertionToken)
.hasErrors()).isFalse();
}
@Test
public void authenticateWhenDelegatingToDefaultAssertionValidatorThenUses() {
OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
provider.setAssertionValidator(assertionToken -> {
ValidationContext context = new ValidationContext();
return createDefaultAssertionValidator(context).convert(assertionToken)
.concat(new Saml2Error("wrong error", "wrong error"));
});
provider.setAssertionValidator(assertionToken ->
createDefaultAssertionValidator(token -> new ValidationContext()).convert(assertionToken)
.concat(new Saml2Error("wrong error", "wrong error")));
Response response = response();
Assertion assertion = assertion();
OneTimeUse oneTimeUse = build(OneTimeUse.DEFAULT_ELEMENT_NAME);
@ -375,12 +378,9 @@ public class OpenSamlAuthenticationProviderTests {
Converter<OpenSamlAuthenticationProvider.AssertionToken, Saml2ResponseValidatorResult> validator =
mock(Converter.class);
OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
provider.setAssertionValidator(assertionToken -> {
ValidationContext context = new ValidationContext(
Collections.singletonMap(SC_VALID_RECIPIENTS, singleton(DESTINATION)));
return createDefaultAssertionValidator(context).convert(assertionToken)
.concat(validator.convert(assertionToken));
});
provider.setAssertionValidator(assertionToken ->
createDefaultAssertionValidator().convert(assertionToken)
.concat(validator.convert(assertionToken)));
Response response = response();
Assertion assertion = assertion();
response.getAssertions().add(assertion);
@ -410,18 +410,19 @@ public class OpenSamlAuthenticationProviderTests {
@Test
public void authenticateWhenValidationContextCustomizedThenUsers() {
Map<String, Object> parameters = new HashMap<>();
parameters.put(SC_VALID_RECIPIENTS, singleton(DESTINATION));
parameters.put(SIGNATURE_REQUIRED, false);
parameters.put(SC_VALID_RECIPIENTS, singleton("blah"));
ValidationContext context = mock(ValidationContext.class);
when(context.getStaticParameters()).thenReturn(parameters);
OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
provider.setAssertionValidator(assertionToken -> createDefaultAssertionValidator(context).convert(assertionToken));
provider.setAssertionValidator(createDefaultAssertionValidator(assertionToken -> context));
Response response = response();
Assertion assertion = assertion();
response.getAssertions().add(assertion);
signed(response, assertingPartySigningCredential(), ASSERTING_PARTY_ENTITY_ID);
Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
provider.authenticate(token);
assertThatThrownBy(() -> provider.authenticate(token))
.isInstanceOf(Saml2AuthenticationException.class)
.hasMessageContaining("Invalid assertion");
verify(context, atLeastOnce()).getStaticParameters();
}
@ -506,6 +507,10 @@ public class OpenSamlAuthenticationProviderTests {
};
}
private Saml2AuthenticationToken token() {
return token(response(), relyingPartyVerifyingCredential());
}
private Saml2AuthenticationToken token(Response response, Saml2X509Credential... credentials) {
String payload = serialize(response);
return token(payload, credentials);

View File

@ -19,7 +19,10 @@ package org.springframework.security.saml2.provider.service.authentication;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;
@ -45,6 +48,7 @@ import org.opensaml.core.xml.schema.impl.XSStringBuilder;
import org.opensaml.core.xml.schema.impl.XSURIBuilder;
import org.opensaml.saml.common.SAMLVersion;
import org.opensaml.saml.common.SignableSAMLObject;
import org.opensaml.saml.common.assertion.ValidationContext;
import org.opensaml.saml.saml2.core.Assertion;
import org.opensaml.saml.saml2.core.Attribute;
import org.opensaml.saml.saml2.core.AttributeStatement;
@ -79,6 +83,7 @@ import org.springframework.security.saml2.core.OpenSamlInitializationService;
import org.springframework.security.saml2.core.Saml2X509Credential;
import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getBuilderFactory;
import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS;
import static org.springframework.security.saml2.core.TestSaml2X509Credentials.assertingPartySigningCredential;
final class TestOpenSamlObjects {
@ -371,6 +376,12 @@ final class TestOpenSamlObjects {
return attributeStatements;
}
static ValidationContext validationContext() {
Map<String, Object> params = new HashMap<>();
params.put(SC_VALID_RECIPIENTS, Collections.singleton(DESTINATION));
return new ValidationContext(params);
}
static <T extends XMLObject> T build(QName qName) {
return (T) getBuilderFactory().getBuilder(qName).buildObject(qName);
}