parent
d9d8253603
commit
a402c3884a
|
@ -21,13 +21,13 @@ import java.time.Duration;
|
|||
import java.time.Instant;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.function.Function;
|
||||
import javax.annotation.Nonnull;
|
||||
|
||||
|
@ -193,10 +193,12 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
|
|||
|
||||
private Converter<Saml2AuthenticationToken, SignatureTrustEngine> signatureTrustEngineConverter =
|
||||
new SignatureTrustEngineConverter();
|
||||
private Converter<Saml2AuthenticationToken, SAML20AssertionValidator> assertionValidatorConverter =
|
||||
private Converter<Tuple, SAML20AssertionValidator> assertionValidatorConverter =
|
||||
new SAML20AssertionValidatorConverter();
|
||||
private Converter<Saml2AuthenticationToken, ValidationContext> validationContextConverter =
|
||||
new ValidationContextConverter(params -> {});
|
||||
private Collection<ConditionValidator> conditionValidators =
|
||||
Collections.singleton(new AudienceRestrictionConditionValidator());
|
||||
private Converter<Tuple, ValidationContext> validationContextConverter =
|
||||
new ValidationContextConverter();
|
||||
private Converter<Saml2AuthenticationToken, Decrypter> decrypterConverter = new DecrypterConverter();
|
||||
|
||||
/**
|
||||
|
@ -209,6 +211,33 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
|
|||
this.parserPool = this.registry.getParserPool();
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the the collection of {@link ConditionValidator}s used when validating an assertion.
|
||||
*
|
||||
* @param conditionValidators the collection of validators to use
|
||||
* @since 5.4
|
||||
*/
|
||||
public void setConditionValidators(
|
||||
Collection<ConditionValidator> conditionValidators) {
|
||||
|
||||
Assert.notEmpty(conditionValidators, "conditionValidators cannot be empty");
|
||||
this.conditionValidators = conditionValidators;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the strategy for retrieving the {@link ValidationContext} used when
|
||||
* validating an assertion.
|
||||
*
|
||||
* @param validationContextConverter the strategy to use
|
||||
* @since 5.4
|
||||
*/
|
||||
public void setValidationContextConverter(
|
||||
Converter<Tuple, ValidationContext> validationContextConverter) {
|
||||
|
||||
Assert.notNull(validationContextConverter, "validationContextConverter cannot be empty");
|
||||
this.validationContextConverter = validationContextConverter;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the {@link Converter} used for extracting assertion attributes that
|
||||
* can be mapped to authorities.
|
||||
|
@ -238,8 +267,6 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
|
|||
*/
|
||||
public void setResponseTimeValidationSkew(Duration responseTimeValidationSkew) {
|
||||
this.responseTimeValidationSkew = responseTimeValidationSkew;
|
||||
this.validationContextConverter = new ValidationContextConverter(
|
||||
params -> params.put(CLOCK_SKEW, responseTimeValidationSkew.toMillis()));
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -303,7 +330,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
|
|||
throw authException(INVALID_SIGNATURE, "Either the response or one of the assertions is unsigned. " +
|
||||
"Please either sign the response or all of the assertions.");
|
||||
}
|
||||
validationExceptions.putAll(validateAssertions(token, assertions));
|
||||
validationExceptions.putAll(validateAssertions(token, response));
|
||||
|
||||
Assertion firstAssertion = CollectionUtils.firstElement(response.getAssertions());
|
||||
NameID nameId = decryptPrincipal(decrypter, firstAssertion);
|
||||
|
@ -392,7 +419,8 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
|
|||
}
|
||||
|
||||
private Map<String, Saml2AuthenticationException> validateAssertions
|
||||
(Saml2AuthenticationToken token, List<Assertion> assertions) {
|
||||
(Saml2AuthenticationToken token, Response response) {
|
||||
List<Assertion> assertions = response.getAssertions();
|
||||
if (assertions.isEmpty()) {
|
||||
throw authException(MALFORMED_RESPONSE_DATA, "No assertions found in response.");
|
||||
}
|
||||
|
@ -401,14 +429,16 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
|
|||
if (logger.isDebugEnabled()) {
|
||||
logger.debug("Validating " + assertions.size() + " assertions");
|
||||
}
|
||||
|
||||
Tuple tuple = new Tuple(token, response);
|
||||
SAML20AssertionValidator validator = this.assertionValidatorConverter.convert(tuple);
|
||||
ValidationContext context = this.validationContextConverter.convert(tuple);
|
||||
for (Assertion assertion : assertions) {
|
||||
if (logger.isTraceEnabled()) {
|
||||
logger.trace("Validating assertion " + assertion.getID());
|
||||
}
|
||||
try {
|
||||
ValidationContext context = this.validationContextConverter.convert(token);
|
||||
ValidationResult result = this.assertionValidatorConverter.convert(token).validate(assertion, context);
|
||||
if (result != ValidationResult.VALID) {
|
||||
if (validator.validate(assertion, context) != ValidationResult.VALID) {
|
||||
String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s",
|
||||
assertion.getID(), ((Response) assertion.getParent()).getID(),
|
||||
context.getValidationFailureMessage());
|
||||
|
@ -512,6 +542,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
|
|||
}
|
||||
|
||||
private static class SignatureTrustEngineConverter implements Converter<Saml2AuthenticationToken, SignatureTrustEngine> {
|
||||
|
||||
@Override
|
||||
public SignatureTrustEngine convert(Saml2AuthenticationToken token) {
|
||||
Set<Credential> credentials = new HashSet<>();
|
||||
|
@ -530,35 +561,27 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
|
|||
}
|
||||
}
|
||||
|
||||
private static class ValidationContextConverter implements Converter<Saml2AuthenticationToken, ValidationContext> {
|
||||
Consumer<Map<String, Object>> validationContextParametersConverter;
|
||||
|
||||
ValidationContextConverter(Consumer<Map<String, Object>> validationContextParametersConverter) {
|
||||
this.validationContextParametersConverter = validationContextParametersConverter;
|
||||
}
|
||||
private class ValidationContextConverter implements Converter<Tuple, ValidationContext> {
|
||||
|
||||
@Override
|
||||
public ValidationContext convert(Saml2AuthenticationToken token) {
|
||||
String audience = token.getRelyingPartyRegistration().getEntityId();
|
||||
String recipient = token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation();
|
||||
public ValidationContext convert(Tuple tuple) {
|
||||
String audience = tuple.authentication.getRelyingPartyRegistration().getEntityId();
|
||||
String recipient = tuple.authentication.getRelyingPartyRegistration().getAssertionConsumerServiceLocation();
|
||||
Map<String, Object> params = new HashMap<>();
|
||||
params.put(CLOCK_SKEW, Duration.ofMinutes(5).toMillis());
|
||||
params.put(CLOCK_SKEW, OpenSamlAuthenticationProvider.this.responseTimeValidationSkew.toMillis());
|
||||
params.put(COND_VALID_AUDIENCES, singleton(audience));
|
||||
params.put(SC_VALID_RECIPIENTS, singleton(recipient));
|
||||
params.put(SIGNATURE_REQUIRED, false); // this verification is performed earlier
|
||||
this.validationContextParametersConverter.accept(params);
|
||||
return new ValidationContext(params);
|
||||
}
|
||||
}
|
||||
|
||||
private class SAML20AssertionValidatorConverter implements Converter<Saml2AuthenticationToken, SAML20AssertionValidator> {
|
||||
private final Collection<ConditionValidator> conditions = new ArrayList<>();
|
||||
private class SAML20AssertionValidatorConverter implements Converter<Tuple, SAML20AssertionValidator> {
|
||||
private final Collection<SubjectConfirmationValidator> subjects = new ArrayList<>();
|
||||
private final Collection<StatementValidator> statements = new ArrayList<>();
|
||||
private final SignaturePrevalidator validator = new SAMLSignatureProfileValidator();
|
||||
|
||||
SAML20AssertionValidatorConverter() {
|
||||
this.conditions.add(new AudienceRestrictionConditionValidator());
|
||||
this.subjects.add(new BearerSubjectConfirmationValidator() {
|
||||
@Nonnull
|
||||
@Override
|
||||
|
@ -571,9 +594,11 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
|
|||
}
|
||||
|
||||
@Override
|
||||
public SAML20AssertionValidator convert(Saml2AuthenticationToken token) {
|
||||
return new SAML20AssertionValidator(this.conditions, this.subjects, this.statements,
|
||||
OpenSamlAuthenticationProvider.this.signatureTrustEngineConverter.convert(token),
|
||||
public SAML20AssertionValidator convert(Tuple tuple) {
|
||||
Collection<ConditionValidator> conditions =
|
||||
OpenSamlAuthenticationProvider.this.conditionValidators;
|
||||
return new SAML20AssertionValidator(conditions, this.subjects, this.statements,
|
||||
OpenSamlAuthenticationProvider.this.signatureTrustEngineConverter.convert(tuple.authentication),
|
||||
this.validator);
|
||||
}
|
||||
}
|
||||
|
@ -616,4 +641,27 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
|
|||
|
||||
return new Saml2AuthenticationException(validationError(code, description), cause);
|
||||
}
|
||||
|
||||
/**
|
||||
* A tuple containing the authentication token and the associated OpenSAML {@link Response}.
|
||||
*
|
||||
* @since 5.4
|
||||
*/
|
||||
public static class Tuple {
|
||||
private final Saml2AuthenticationToken authentication;
|
||||
private final Response response;
|
||||
|
||||
private Tuple(Saml2AuthenticationToken authentication, Response response) {
|
||||
this.authentication = authentication;
|
||||
this.response = response;
|
||||
}
|
||||
|
||||
public Saml2AuthenticationToken getAuthentication() {
|
||||
return this.authentication;
|
||||
}
|
||||
|
||||
public Response getResponse() {
|
||||
return this.response;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -23,9 +23,11 @@ import java.io.StringReader;
|
|||
import java.time.Instant;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import javax.xml.namespace.QName;
|
||||
import javax.xml.parsers.DocumentBuilder;
|
||||
import javax.xml.parsers.DocumentBuilderFactory;
|
||||
|
||||
|
@ -42,12 +44,17 @@ import org.opensaml.core.xml.XMLObject;
|
|||
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
|
||||
import org.opensaml.core.xml.io.Marshaller;
|
||||
import org.opensaml.core.xml.io.MarshallingException;
|
||||
import org.opensaml.saml.common.assertion.ValidationContext;
|
||||
import org.opensaml.saml.common.assertion.ValidationResult;
|
||||
import org.opensaml.saml.saml2.assertion.impl.OneTimeUseConditionValidator;
|
||||
import org.opensaml.saml.saml2.core.Assertion;
|
||||
import org.opensaml.saml.saml2.core.AttributeStatement;
|
||||
import org.opensaml.saml.saml2.core.AttributeValue;
|
||||
import org.opensaml.saml.saml2.core.Condition;
|
||||
import org.opensaml.saml.saml2.core.EncryptedAssertion;
|
||||
import org.opensaml.saml.saml2.core.EncryptedID;
|
||||
import org.opensaml.saml.saml2.core.NameID;
|
||||
import org.opensaml.saml.saml2.core.OneTimeUse;
|
||||
import org.opensaml.saml.saml2.core.Response;
|
||||
import org.w3c.dom.Document;
|
||||
import org.w3c.dom.Element;
|
||||
|
@ -57,7 +64,9 @@ import org.springframework.security.core.Authentication;
|
|||
import org.springframework.security.saml2.Saml2Exception;
|
||||
import org.springframework.security.saml2.credentials.Saml2X509Credential;
|
||||
|
||||
import static java.util.Collections.singleton;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatCode;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.atLeastOnce;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
@ -65,6 +74,8 @@ import static org.mockito.Mockito.verify;
|
|||
import static org.mockito.Mockito.when;
|
||||
import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getBuilderFactory;
|
||||
import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getMarshallerFactory;
|
||||
import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS;
|
||||
import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SIGNATURE_REQUIRED;
|
||||
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyEncryptingCredential;
|
||||
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyPrivateCredential;
|
||||
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartySigningCredential;
|
||||
|
@ -353,6 +364,62 @@ public class OpenSamlAuthenticationProviderTests {
|
|||
objectOutputStream.flush();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void authenticateWhenConditionValidatorsCustomizedThenUses() throws Exception {
|
||||
OneTimeUseConditionValidator validator = mock(OneTimeUseConditionValidator.class);
|
||||
OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
|
||||
provider.setConditionValidators(Collections.singleton(validator));
|
||||
Response response = response();
|
||||
Assertion assertion = assertion();
|
||||
OneTimeUse oneTimeUse = build(OneTimeUse.DEFAULT_ELEMENT_NAME);
|
||||
assertion.getConditions().getConditions().add(oneTimeUse);
|
||||
response.getAssertions().add(assertion);
|
||||
signed(response, assertingPartySigningCredential(), ASSERTING_PARTY_ENTITY_ID);
|
||||
Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
|
||||
when(validator.getServicedCondition()).thenReturn(OneTimeUse.DEFAULT_ELEMENT_NAME);
|
||||
when(validator.validate(any(Condition.class), any(Assertion.class), any(ValidationContext.class)))
|
||||
.thenReturn(ValidationResult.VALID);
|
||||
provider.authenticate(token);
|
||||
verify(validator).validate(any(Condition.class), any(Assertion.class), any(ValidationContext.class));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void authenticateWhenValidationContextCustomizedThenUsers() {
|
||||
Map<String, Object> parameters = new HashMap<>();
|
||||
parameters.put(SC_VALID_RECIPIENTS, singleton(DESTINATION));
|
||||
parameters.put(SIGNATURE_REQUIRED, false);
|
||||
ValidationContext context = mock(ValidationContext.class);
|
||||
when(context.getStaticParameters()).thenReturn(parameters);
|
||||
OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
|
||||
provider.setValidationContextConverter(tuple -> context);
|
||||
Response response = response();
|
||||
Assertion assertion = assertion();
|
||||
response.getAssertions().add(assertion);
|
||||
signed(response, assertingPartySigningCredential(), ASSERTING_PARTY_ENTITY_ID);
|
||||
Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
|
||||
provider.authenticate(token);
|
||||
verify(context, atLeastOnce()).getStaticParameters();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void setValidationContextConverterWhenNullThenIllegalArgument() {
|
||||
assertThatCode(() -> this.provider.setValidationContextConverter(null))
|
||||
.isInstanceOf(IllegalArgumentException.class);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void setConditionValidatorsWhenNullOrEmptyThenIllegalArgument() {
|
||||
assertThatCode(() -> this.provider.setConditionValidators(null))
|
||||
.isInstanceOf(IllegalArgumentException.class);
|
||||
|
||||
assertThatCode(() -> this.provider.setConditionValidators(Collections.emptyList()))
|
||||
.isInstanceOf(IllegalArgumentException.class);
|
||||
}
|
||||
|
||||
private <T extends XMLObject> T build(QName qName) {
|
||||
return (T) getBuilderFactory().getBuilder(qName).buildObject(qName);
|
||||
}
|
||||
|
||||
private String serialize(XMLObject object) {
|
||||
try {
|
||||
Marshaller marshaller = getMarshallerFactory().getMarshaller(object);
|
||||
|
|
Loading…
Reference in New Issue