Add ConditionValidator Support

Closes gh-8769
This commit is contained in:
Josh Cummings 2020-08-04 07:58:26 -06:00
parent d9d8253603
commit a402c3884a
No known key found for this signature in database
GPG Key ID: 49EF60DD7FF83443
2 changed files with 143 additions and 28 deletions

View File

@ -21,13 +21,13 @@ import java.time.Duration;
import java.time.Instant; import java.time.Instant;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Function; import java.util.function.Function;
import javax.annotation.Nonnull; import javax.annotation.Nonnull;
@ -193,10 +193,12 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
private Converter<Saml2AuthenticationToken, SignatureTrustEngine> signatureTrustEngineConverter = private Converter<Saml2AuthenticationToken, SignatureTrustEngine> signatureTrustEngineConverter =
new SignatureTrustEngineConverter(); new SignatureTrustEngineConverter();
private Converter<Saml2AuthenticationToken, SAML20AssertionValidator> assertionValidatorConverter = private Converter<Tuple, SAML20AssertionValidator> assertionValidatorConverter =
new SAML20AssertionValidatorConverter(); new SAML20AssertionValidatorConverter();
private Converter<Saml2AuthenticationToken, ValidationContext> validationContextConverter = private Collection<ConditionValidator> conditionValidators =
new ValidationContextConverter(params -> {}); Collections.singleton(new AudienceRestrictionConditionValidator());
private Converter<Tuple, ValidationContext> validationContextConverter =
new ValidationContextConverter();
private Converter<Saml2AuthenticationToken, Decrypter> decrypterConverter = new DecrypterConverter(); private Converter<Saml2AuthenticationToken, Decrypter> decrypterConverter = new DecrypterConverter();
/** /**
@ -209,6 +211,33 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
this.parserPool = this.registry.getParserPool(); this.parserPool = this.registry.getParserPool();
} }
/**
* Set the the collection of {@link ConditionValidator}s used when validating an assertion.
*
* @param conditionValidators the collection of validators to use
* @since 5.4
*/
public void setConditionValidators(
Collection<ConditionValidator> conditionValidators) {
Assert.notEmpty(conditionValidators, "conditionValidators cannot be empty");
this.conditionValidators = conditionValidators;
}
/**
* Set the strategy for retrieving the {@link ValidationContext} used when
* validating an assertion.
*
* @param validationContextConverter the strategy to use
* @since 5.4
*/
public void setValidationContextConverter(
Converter<Tuple, ValidationContext> validationContextConverter) {
Assert.notNull(validationContextConverter, "validationContextConverter cannot be empty");
this.validationContextConverter = validationContextConverter;
}
/** /**
* Sets the {@link Converter} used for extracting assertion attributes that * Sets the {@link Converter} used for extracting assertion attributes that
* can be mapped to authorities. * can be mapped to authorities.
@ -238,8 +267,6 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
*/ */
public void setResponseTimeValidationSkew(Duration responseTimeValidationSkew) { public void setResponseTimeValidationSkew(Duration responseTimeValidationSkew) {
this.responseTimeValidationSkew = responseTimeValidationSkew; this.responseTimeValidationSkew = responseTimeValidationSkew;
this.validationContextConverter = new ValidationContextConverter(
params -> params.put(CLOCK_SKEW, responseTimeValidationSkew.toMillis()));
} }
/** /**
@ -303,7 +330,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
throw authException(INVALID_SIGNATURE, "Either the response or one of the assertions is unsigned. " + throw authException(INVALID_SIGNATURE, "Either the response or one of the assertions is unsigned. " +
"Please either sign the response or all of the assertions."); "Please either sign the response or all of the assertions.");
} }
validationExceptions.putAll(validateAssertions(token, assertions)); validationExceptions.putAll(validateAssertions(token, response));
Assertion firstAssertion = CollectionUtils.firstElement(response.getAssertions()); Assertion firstAssertion = CollectionUtils.firstElement(response.getAssertions());
NameID nameId = decryptPrincipal(decrypter, firstAssertion); NameID nameId = decryptPrincipal(decrypter, firstAssertion);
@ -392,7 +419,8 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
} }
private Map<String, Saml2AuthenticationException> validateAssertions private Map<String, Saml2AuthenticationException> validateAssertions
(Saml2AuthenticationToken token, List<Assertion> assertions) { (Saml2AuthenticationToken token, Response response) {
List<Assertion> assertions = response.getAssertions();
if (assertions.isEmpty()) { if (assertions.isEmpty()) {
throw authException(MALFORMED_RESPONSE_DATA, "No assertions found in response."); throw authException(MALFORMED_RESPONSE_DATA, "No assertions found in response.");
} }
@ -401,14 +429,16 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("Validating " + assertions.size() + " assertions"); logger.debug("Validating " + assertions.size() + " assertions");
} }
Tuple tuple = new Tuple(token, response);
SAML20AssertionValidator validator = this.assertionValidatorConverter.convert(tuple);
ValidationContext context = this.validationContextConverter.convert(tuple);
for (Assertion assertion : assertions) { for (Assertion assertion : assertions) {
if (logger.isTraceEnabled()) { if (logger.isTraceEnabled()) {
logger.trace("Validating assertion " + assertion.getID()); logger.trace("Validating assertion " + assertion.getID());
} }
try { try {
ValidationContext context = this.validationContextConverter.convert(token); if (validator.validate(assertion, context) != ValidationResult.VALID) {
ValidationResult result = this.assertionValidatorConverter.convert(token).validate(assertion, context);
if (result != ValidationResult.VALID) {
String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s", String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s",
assertion.getID(), ((Response) assertion.getParent()).getID(), assertion.getID(), ((Response) assertion.getParent()).getID(),
context.getValidationFailureMessage()); context.getValidationFailureMessage());
@ -512,6 +542,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
} }
private static class SignatureTrustEngineConverter implements Converter<Saml2AuthenticationToken, SignatureTrustEngine> { private static class SignatureTrustEngineConverter implements Converter<Saml2AuthenticationToken, SignatureTrustEngine> {
@Override @Override
public SignatureTrustEngine convert(Saml2AuthenticationToken token) { public SignatureTrustEngine convert(Saml2AuthenticationToken token) {
Set<Credential> credentials = new HashSet<>(); Set<Credential> credentials = new HashSet<>();
@ -530,35 +561,27 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
} }
} }
private static class ValidationContextConverter implements Converter<Saml2AuthenticationToken, ValidationContext> { private class ValidationContextConverter implements Converter<Tuple, ValidationContext> {
Consumer<Map<String, Object>> validationContextParametersConverter;
ValidationContextConverter(Consumer<Map<String, Object>> validationContextParametersConverter) {
this.validationContextParametersConverter = validationContextParametersConverter;
}
@Override @Override
public ValidationContext convert(Saml2AuthenticationToken token) { public ValidationContext convert(Tuple tuple) {
String audience = token.getRelyingPartyRegistration().getEntityId(); String audience = tuple.authentication.getRelyingPartyRegistration().getEntityId();
String recipient = token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation(); String recipient = tuple.authentication.getRelyingPartyRegistration().getAssertionConsumerServiceLocation();
Map<String, Object> params = new HashMap<>(); Map<String, Object> params = new HashMap<>();
params.put(CLOCK_SKEW, Duration.ofMinutes(5).toMillis()); params.put(CLOCK_SKEW, OpenSamlAuthenticationProvider.this.responseTimeValidationSkew.toMillis());
params.put(COND_VALID_AUDIENCES, singleton(audience)); params.put(COND_VALID_AUDIENCES, singleton(audience));
params.put(SC_VALID_RECIPIENTS, singleton(recipient)); params.put(SC_VALID_RECIPIENTS, singleton(recipient));
params.put(SIGNATURE_REQUIRED, false); // this verification is performed earlier params.put(SIGNATURE_REQUIRED, false); // this verification is performed earlier
this.validationContextParametersConverter.accept(params);
return new ValidationContext(params); return new ValidationContext(params);
} }
} }
private class SAML20AssertionValidatorConverter implements Converter<Saml2AuthenticationToken, SAML20AssertionValidator> { private class SAML20AssertionValidatorConverter implements Converter<Tuple, SAML20AssertionValidator> {
private final Collection<ConditionValidator> conditions = new ArrayList<>();
private final Collection<SubjectConfirmationValidator> subjects = new ArrayList<>(); private final Collection<SubjectConfirmationValidator> subjects = new ArrayList<>();
private final Collection<StatementValidator> statements = new ArrayList<>(); private final Collection<StatementValidator> statements = new ArrayList<>();
private final SignaturePrevalidator validator = new SAMLSignatureProfileValidator(); private final SignaturePrevalidator validator = new SAMLSignatureProfileValidator();
SAML20AssertionValidatorConverter() { SAML20AssertionValidatorConverter() {
this.conditions.add(new AudienceRestrictionConditionValidator());
this.subjects.add(new BearerSubjectConfirmationValidator() { this.subjects.add(new BearerSubjectConfirmationValidator() {
@Nonnull @Nonnull
@Override @Override
@ -571,9 +594,11 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
} }
@Override @Override
public SAML20AssertionValidator convert(Saml2AuthenticationToken token) { public SAML20AssertionValidator convert(Tuple tuple) {
return new SAML20AssertionValidator(this.conditions, this.subjects, this.statements, Collection<ConditionValidator> conditions =
OpenSamlAuthenticationProvider.this.signatureTrustEngineConverter.convert(token), OpenSamlAuthenticationProvider.this.conditionValidators;
return new SAML20AssertionValidator(conditions, this.subjects, this.statements,
OpenSamlAuthenticationProvider.this.signatureTrustEngineConverter.convert(tuple.authentication),
this.validator); this.validator);
} }
} }
@ -616,4 +641,27 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
return new Saml2AuthenticationException(validationError(code, description), cause); return new Saml2AuthenticationException(validationError(code, description), cause);
} }
/**
* A tuple containing the authentication token and the associated OpenSAML {@link Response}.
*
* @since 5.4
*/
public static class Tuple {
private final Saml2AuthenticationToken authentication;
private final Response response;
private Tuple(Saml2AuthenticationToken authentication, Response response) {
this.authentication = authentication;
this.response = response;
}
public Saml2AuthenticationToken getAuthentication() {
return this.authentication;
}
public Response getResponse() {
return this.response;
}
}
} }

View File

@ -23,9 +23,11 @@ import java.io.StringReader;
import java.time.Instant; import java.time.Instant;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import javax.xml.namespace.QName;
import javax.xml.parsers.DocumentBuilder; import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory; import javax.xml.parsers.DocumentBuilderFactory;
@ -42,12 +44,17 @@ import org.opensaml.core.xml.XMLObject;
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
import org.opensaml.core.xml.io.Marshaller; import org.opensaml.core.xml.io.Marshaller;
import org.opensaml.core.xml.io.MarshallingException; import org.opensaml.core.xml.io.MarshallingException;
import org.opensaml.saml.common.assertion.ValidationContext;
import org.opensaml.saml.common.assertion.ValidationResult;
import org.opensaml.saml.saml2.assertion.impl.OneTimeUseConditionValidator;
import org.opensaml.saml.saml2.core.Assertion; import org.opensaml.saml.saml2.core.Assertion;
import org.opensaml.saml.saml2.core.AttributeStatement; import org.opensaml.saml.saml2.core.AttributeStatement;
import org.opensaml.saml.saml2.core.AttributeValue; import org.opensaml.saml.saml2.core.AttributeValue;
import org.opensaml.saml.saml2.core.Condition;
import org.opensaml.saml.saml2.core.EncryptedAssertion; import org.opensaml.saml.saml2.core.EncryptedAssertion;
import org.opensaml.saml.saml2.core.EncryptedID; import org.opensaml.saml.saml2.core.EncryptedID;
import org.opensaml.saml.saml2.core.NameID; import org.opensaml.saml.saml2.core.NameID;
import org.opensaml.saml.saml2.core.OneTimeUse;
import org.opensaml.saml.saml2.core.Response; import org.opensaml.saml.saml2.core.Response;
import org.w3c.dom.Document; import org.w3c.dom.Document;
import org.w3c.dom.Element; import org.w3c.dom.Element;
@ -57,7 +64,9 @@ import org.springframework.security.core.Authentication;
import org.springframework.security.saml2.Saml2Exception; import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.credentials.Saml2X509Credential; import org.springframework.security.saml2.credentials.Saml2X509Credential;
import static java.util.Collections.singleton;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
@ -65,6 +74,8 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getBuilderFactory; import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getBuilderFactory;
import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getMarshallerFactory; import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getMarshallerFactory;
import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS;
import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SIGNATURE_REQUIRED;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyEncryptingCredential; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyEncryptingCredential;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyPrivateCredential; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyPrivateCredential;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartySigningCredential; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartySigningCredential;
@ -353,6 +364,62 @@ public class OpenSamlAuthenticationProviderTests {
objectOutputStream.flush(); objectOutputStream.flush();
} }
@Test
public void authenticateWhenConditionValidatorsCustomizedThenUses() throws Exception {
OneTimeUseConditionValidator validator = mock(OneTimeUseConditionValidator.class);
OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
provider.setConditionValidators(Collections.singleton(validator));
Response response = response();
Assertion assertion = assertion();
OneTimeUse oneTimeUse = build(OneTimeUse.DEFAULT_ELEMENT_NAME);
assertion.getConditions().getConditions().add(oneTimeUse);
response.getAssertions().add(assertion);
signed(response, assertingPartySigningCredential(), ASSERTING_PARTY_ENTITY_ID);
Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
when(validator.getServicedCondition()).thenReturn(OneTimeUse.DEFAULT_ELEMENT_NAME);
when(validator.validate(any(Condition.class), any(Assertion.class), any(ValidationContext.class)))
.thenReturn(ValidationResult.VALID);
provider.authenticate(token);
verify(validator).validate(any(Condition.class), any(Assertion.class), any(ValidationContext.class));
}
@Test
public void authenticateWhenValidationContextCustomizedThenUsers() {
Map<String, Object> parameters = new HashMap<>();
parameters.put(SC_VALID_RECIPIENTS, singleton(DESTINATION));
parameters.put(SIGNATURE_REQUIRED, false);
ValidationContext context = mock(ValidationContext.class);
when(context.getStaticParameters()).thenReturn(parameters);
OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
provider.setValidationContextConverter(tuple -> context);
Response response = response();
Assertion assertion = assertion();
response.getAssertions().add(assertion);
signed(response, assertingPartySigningCredential(), ASSERTING_PARTY_ENTITY_ID);
Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
provider.authenticate(token);
verify(context, atLeastOnce()).getStaticParameters();
}
@Test
public void setValidationContextConverterWhenNullThenIllegalArgument() {
assertThatCode(() -> this.provider.setValidationContextConverter(null))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void setConditionValidatorsWhenNullOrEmptyThenIllegalArgument() {
assertThatCode(() -> this.provider.setConditionValidators(null))
.isInstanceOf(IllegalArgumentException.class);
assertThatCode(() -> this.provider.setConditionValidators(Collections.emptyList()))
.isInstanceOf(IllegalArgumentException.class);
}
private <T extends XMLObject> T build(QName qName) {
return (T) getBuilderFactory().getBuilder(qName).buildObject(qName);
}
private String serialize(XMLObject object) { private String serialize(XMLObject object) {
try { try {
Marshaller marshaller = getMarshallerFactory().getMarshaller(object); Marshaller marshaller = getMarshallerFactory().getMarshaller(object);