Polish Configurable SAML Decryption Support

- Renamed to setResponseElementsDecrypter and
setAssertionElementsDecrypter to align with ResponseToken and
AssertionToken
- Changed contract of setAssertionElementsDecrypter to use
AssertionToken
- Changed assertions in unit test to use isEqualTo

Issue gh-9044
This commit is contained in:
Josh Cummings 2020-09-30 17:01:23 -06:00
parent 535ae3e27d
commit d0581c9a26
No known key found for this signature in database
GPG Key ID: 49EF60DD7FF83443
2 changed files with 295 additions and 229 deletions

View File

@ -88,7 +88,6 @@ import org.opensaml.security.criteria.UsageCriterion;
import org.opensaml.security.x509.BasicX509Credential; import org.opensaml.security.x509.BasicX509Credential;
import org.opensaml.xmlsec.config.impl.DefaultSecurityConfigurationBootstrap; import org.opensaml.xmlsec.config.impl.DefaultSecurityConfigurationBootstrap;
import org.opensaml.xmlsec.encryption.support.ChainingEncryptedKeyResolver; import org.opensaml.xmlsec.encryption.support.ChainingEncryptedKeyResolver;
import org.opensaml.xmlsec.encryption.support.DecryptionException;
import org.opensaml.xmlsec.encryption.support.EncryptedKeyResolver; import org.opensaml.xmlsec.encryption.support.EncryptedKeyResolver;
import org.opensaml.xmlsec.encryption.support.InlineEncryptedKeyResolver; import org.opensaml.xmlsec.encryption.support.InlineEncryptedKeyResolver;
import org.opensaml.xmlsec.encryption.support.SimpleRetrievalMethodEncryptedKeyResolver; import org.opensaml.xmlsec.encryption.support.SimpleRetrievalMethodEncryptedKeyResolver;
@ -185,59 +184,24 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
private Duration responseTimeValidationSkew = Duration.ofMinutes(5); private Duration responseTimeValidationSkew = Duration.ofMinutes(5);
private Converter<ResponseToken, ? extends AbstractAuthenticationToken> responseAuthenticationConverter = ( private Converter<ResponseToken, Saml2ResponseValidatorResult> responseSignatureValidator = createDefaultResponseSignatureValidator();
responseToken) -> {
Response response = responseToken.response;
Saml2AuthenticationToken token = responseToken.token;
Assertion assertion = CollectionUtils.firstElement(response.getAssertions());
String username = assertion.getSubject().getNameID().getValue();
Map<String, List<Object>> attributes = getAssertionAttributes(assertion);
return new Saml2Authentication(new DefaultSaml2AuthenticatedPrincipal(username, attributes),
token.getSaml2Response(), this.authoritiesMapper.mapAuthorities(getAssertionAuthorities(assertion)));
};
private Converter<AssertionToken, Saml2ResponseValidatorResult> assertionSignatureValidator = createDefaultAssertionValidator( private Consumer<ResponseToken> responseElementsDecrypter = createDefaultResponseElementsDecrypter();
Saml2ErrorCodes.INVALID_SIGNATURE, (assertionToken) -> {
SignatureTrustEngine engine = this.signatureTrustEngineConverter.convert(assertionToken.token);
return SAML20AssertionValidators.createSignatureValidator(engine);
}, (assertionToken) -> new ValidationContext(
Collections.singletonMap(SAML2AssertionValidationParameters.SIGNATURE_REQUIRED, false)));
private Converter<AssertionToken, Saml2ResponseValidatorResult> assertionValidator = createDefaultAssertionValidator( private Converter<ResponseToken, Saml2ResponseValidatorResult> responseValidator = createDefaultResponseValidator();
Saml2ErrorCodes.INVALID_ASSERTION, (assertionToken) -> SAML20AssertionValidators.attributeValidator,
(assertionToken) -> createValidationContext(assertionToken, (params) -> params private Converter<AssertionToken, Saml2ResponseValidatorResult> assertionSignatureValidator = createDefaultAssertionSignatureValidator();
.put(SAML2AssertionValidationParameters.CLOCK_SKEW, this.responseTimeValidationSkew.toMillis())));
private Consumer<AssertionToken> assertionElementsDecrypter = createDefaultAssertionElementsDecrypter();
private Converter<AssertionToken, Saml2ResponseValidatorResult> assertionValidator = createCompatibleAssertionValidator();
private Converter<ResponseToken, ? extends AbstractAuthenticationToken> responseAuthenticationConverter = createCompatibleResponseAuthenticationConverter();
private Converter<Saml2AuthenticationToken, SignatureTrustEngine> signatureTrustEngineConverter = new SignatureTrustEngineConverter(); private Converter<Saml2AuthenticationToken, SignatureTrustEngine> signatureTrustEngineConverter = new SignatureTrustEngineConverter();
private Converter<Saml2AuthenticationToken, Decrypter> decrypterConverter = new DecrypterConverter(); private Converter<Saml2AuthenticationToken, Decrypter> decrypterConverter = new DecrypterConverter();
private Consumer<ResponseToken> assertionDecrypter = (responseToken) -> {
List<Assertion> assertions = new ArrayList<>();
for (EncryptedAssertion encryptedAssertion : responseToken.getResponse().getEncryptedAssertions()) {
try {
Decrypter decrypter = this.decrypterConverter.convert(responseToken.getToken());
Assertion assertion = decrypter.decrypt(encryptedAssertion);
assertions.add(assertion);
}
catch (DecryptionException ex) {
throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
}
}
responseToken.getResponse().getAssertions().addAll(assertions);
};
private Consumer<ResponseToken> principalDecrypter = (responseToken) -> {
try {
Decrypter decrypter = this.decrypterConverter.convert(responseToken.getToken());
Assertion assertion = CollectionUtils.firstElement(responseToken.getResponse().getAssertions());
assertion.getSubject().setNameID((NameID) decrypter.decrypt(assertion.getSubject().getEncryptedID()));
}
catch (DecryptionException ex) {
throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
}
};
/** /**
* Creates an {@link OpenSamlAuthenticationProvider} * Creates an {@link OpenSamlAuthenticationProvider}
*/ */
@ -248,12 +212,60 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
this.parserPool = this.registry.getParserPool(); this.parserPool = this.registry.getParserPool();
} }
/**
* Set the {@link Consumer} strategy to use for decrypting elements of a validated
* {@link Response}. The default strategy decrypts all {@link EncryptedAssertion}s
* using OpenSAML's {@link Decrypter}, adding the results to
* {@link Response#getAssertions()}.
*
* You can use this method to configure the {@link Decrypter} instance like so:
*
* <pre>
* OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
* provider.setResponseElementsDecrypter((responseToken) -> {
* DecrypterParameters parameters = new DecrypterParameters();
* // ... set parameters as needed
* Decrypter decrypter = new Decrypter(parameters);
* Response response = responseToken.getResponse();
* EncryptedAssertion encrypted = response.getEncryptedAssertions().get(0);
* try {
* Assertion assertion = decrypter.decrypt(encrypted);
* response.getAssertions().add(assertion);
* } catch (Exception e) {
* throw new Saml2AuthenticationException(...);
* }
* });
* </pre>
*
* Or, in the event that you have your own custom decryption interface, the same
* pattern applies:
*
* <pre>
* OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
* Converter&lt;EncryptedAssertion, Assertion&gt; myService = ...
* provider.setResponseDecrypter((responseToken) -> {
* Response response = responseToken.getResponse();
* response.getEncryptedAssertions().stream()
* .map(service::decrypt).forEach(response.getAssertions()::add);
* });
* </pre>
*
* This is valuable when using an external service to perform the decryption.
* @param responseElementsDecrypter the {@link Consumer} for decrypting response
* elements
* @since 5.5
*/
public void setResponseElementsDecrypter(Consumer<ResponseToken> responseElementsDecrypter) {
Assert.notNull(responseElementsDecrypter, "responseElementsDecrypter cannot be null");
this.responseElementsDecrypter = responseElementsDecrypter;
}
/** /**
* Set the {@link Converter} to use for validating each {@link Assertion} in the SAML * Set the {@link Converter} to use for validating each {@link Assertion} in the SAML
* 2.0 Response. * 2.0 Response.
* *
* You can still invoke the default validator by delgating to * You can still invoke the default validator by delgating to
* {@link #createDefaultAssertionValidator}, like so: * {@link #createAssertionValidator}, like so:
* *
* <pre> * <pre>
* OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider(); * OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
@ -294,6 +306,49 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
this.assertionValidator = assertionValidator; this.assertionValidator = assertionValidator;
} }
/**
* Set the {@link Consumer} strategy to use for decrypting elements of a validated
* {@link Assertion}.
*
* You can use this method to configure the {@link Decrypter} used like so:
*
* <pre>
* OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
* provider.setResponseDecrypter((assertionToken) -> {
* DecrypterParameters parameters = new DecrypterParameters();
* // ... set parameters as needed
* Decrypter decrypter = new Decrypter(parameters);
* Assertion assertion = assertionToken.getAssertion();
* EncryptedID encrypted = assertion.getSubject().getEncryptedID();
* try {
* NameID name = decrypter.decrypt(encrypted);
* assertion.getSubject().setNameID(name);
* } catch (Exception e) {
* throw new Saml2AuthenticationException(...);
* }
* });
* </pre>
*
* Or, in the event that you have your own custom interface, the same pattern applies:
*
* <pre>
* OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
* MyDecryptionService myService = ...
* provider.setResponseDecrypter((responseToken) -> {
* Assertion assertion = assertionToken.getAssertion();
* EncryptedID encrypted = assertion.getSubject().getEncryptedID();
* NameID name = myService.decrypt(encrypted);
* assertion.getSubject().setNameID(name);
* });
* </pre>
* @param assertionDecrypter the {@link Consumer} for decrypting assertion elements
* @since 5.5
*/
public void setAssertionElementsDecrypter(Consumer<AssertionToken> assertionDecrypter) {
Assert.notNull(assertionDecrypter, "assertionDecrypter cannot be null");
this.assertionElementsDecrypter = assertionDecrypter;
}
/** /**
* Set the {@link Converter} to use for converting a validated {@link Response} into * Set the {@link Converter} to use for converting a validated {@link Response} into
* an {@link AbstractAuthenticationToken}. * an {@link AbstractAuthenticationToken}.
@ -359,52 +414,6 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
this.responseTimeValidationSkew = responseTimeValidationSkew; this.responseTimeValidationSkew = responseTimeValidationSkew;
} }
/**
* Sets the assertion response custom decrypter.
*
* You can use this method like so:
*
* <pre>
* YourDecrypter decrypter = // ... your custom decrypter
*
* OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
* provider.setAssertionDecrypter((responseToken) -> {
* Response response = responseToken.getResponse();
* EncryptedAssertion encrypted = response.getEncryptedAssertions().get(0);
* Assertion assertion = decrypter.decrypt(encrypted);
* response.getAssertions().add(assertion);
* });
* </pre>
* @param assertionDecrypter response token consumer
*/
public void setAssertionDecrypter(Consumer<ResponseToken> assertionDecrypter) {
Assert.notNull(assertionDecrypter, "Consumer<ResponseToken> required");
this.assertionDecrypter = assertionDecrypter;
}
/**
* Sets the principal custom decrypter.
*
* You can use this method like so:
*
* <pre>
* YourDecrypter decrypter = // ... your custom decrypter
*
* OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
* provider.setAssertionDecrypter((responseToken) -> {
* Assertion assertion = CollectionUtils.firstElement(responseToken.getResponse().getAssertions());
* EncryptedID encrypted = assertion.getSubject().getEncryptedID();
* NameID name = decrypter.decrypt(encrypted);
* assertion.getSubject().setNameID(name)
* });
* </pre>
* @param principalDecrypter response token consumer
*/
public void setPrincipalDecrypter(Consumer<ResponseToken> principalDecrypter) {
Assert.notNull(principalDecrypter, "Consumer<ResponseToken> required");
this.principalDecrypter = principalDecrypter;
}
/** /**
* Construct a default strategy for validating each SAML 2.0 Assertion and associated * Construct a default strategy for validating each SAML 2.0 Assertion and associated
* {@link Authentication} token * {@link Authentication} token
@ -413,7 +422,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
*/ */
public static Converter<AssertionToken, Saml2ResponseValidatorResult> createDefaultAssertionValidator() { public static Converter<AssertionToken, Saml2ResponseValidatorResult> createDefaultAssertionValidator() {
return createDefaultAssertionValidator(Saml2ErrorCodes.INVALID_ASSERTION, return createAssertionValidator(Saml2ErrorCodes.INVALID_ASSERTION,
(assertionToken) -> SAML20AssertionValidators.attributeValidator, (assertionToken) -> SAML20AssertionValidators.attributeValidator,
(assertionToken) -> createValidationContext(assertionToken, (params) -> { (assertionToken) -> createValidationContext(assertionToken, (params) -> {
})); }));
@ -430,7 +439,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
public static Converter<AssertionToken, Saml2ResponseValidatorResult> createDefaultAssertionValidator( public static Converter<AssertionToken, Saml2ResponseValidatorResult> createDefaultAssertionValidator(
Converter<AssertionToken, ValidationContext> contextConverter) { Converter<AssertionToken, ValidationContext> contextConverter) {
return createDefaultAssertionValidator(Saml2ErrorCodes.INVALID_ASSERTION, return createAssertionValidator(Saml2ErrorCodes.INVALID_ASSERTION,
(assertionToken) -> SAML20AssertionValidators.attributeValidator, contextConverter); (assertionToken) -> SAML20AssertionValidators.attributeValidator, contextConverter);
} }
@ -480,10 +489,6 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
return authentication != null && Saml2AuthenticationToken.class.isAssignableFrom(authentication); return authentication != null && Saml2AuthenticationToken.class.isAssignableFrom(authentication);
} }
private Collection<? extends GrantedAuthority> getAssertionAuthorities(Assertion assertion) {
return this.authoritiesExtractor.convert(assertion);
}
private Response parse(String response) throws Saml2Exception, Saml2AuthenticationException { private Response parse(String response) throws Saml2Exception, Saml2AuthenticationException {
try { try {
Document document = this.parserPool Document document = this.parserPool
@ -500,20 +505,30 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
String issuer = response.getIssuer().getValue(); String issuer = response.getIssuer().getValue();
logger.debug(LogMessage.format("Processing SAML response from %s", issuer)); logger.debug(LogMessage.format("Processing SAML response from %s", issuer));
boolean responseSigned = response.isSigned(); boolean responseSigned = response.isSigned();
Saml2ResponseValidatorResult result = validateResponse(token, response);
ResponseToken responseToken = new ResponseToken(response, token); ResponseToken responseToken = new ResponseToken(response, token);
List<Assertion> assertions = decryptAssertions(responseToken); Saml2ResponseValidatorResult result = this.responseSignatureValidator.convert(responseToken);
if (!isSigned(responseSigned, assertions)) { if (responseSigned) {
this.responseElementsDecrypter.accept(responseToken);
}
result = result.concat(this.responseValidator.convert(responseToken));
boolean allAssertionsSigned = true;
for (Assertion assertion : response.getAssertions()) {
AssertionToken assertionToken = new AssertionToken(assertion, token);
result = result.concat(this.assertionSignatureValidator.convert(assertionToken));
allAssertionsSigned = allAssertionsSigned && assertion.isSigned();
if (responseSigned || assertion.isSigned()) {
this.assertionElementsDecrypter.accept(new AssertionToken(assertion, token));
}
result = result.concat(this.assertionValidator.convert(assertionToken));
}
if (!responseSigned && !allAssertionsSigned) {
String description = "Either the response or one of the assertions is unsigned. " String description = "Either the response or one of the assertions is unsigned. "
+ "Please either sign the response or all of the assertions."; + "Please either sign the response or all of the assertions.";
throw createAuthenticationException(Saml2ErrorCodes.INVALID_SIGNATURE, description, null); throw createAuthenticationException(Saml2ErrorCodes.INVALID_SIGNATURE, description, null);
} }
result = result.concat(validateAssertions(token, response));
Assertion firstAssertion = CollectionUtils.firstElement(response.getAssertions()); Assertion firstAssertion = CollectionUtils.firstElement(response.getAssertions());
NameID nameId = decryptPrincipal(responseToken); if (!hasName(firstAssertion)) {
if (nameId == null || nameId.getValue() == null) {
Saml2Error error = new Saml2Error(Saml2ErrorCodes.SUBJECT_NOT_FOUND, Saml2Error error = new Saml2Error(Saml2ErrorCodes.SUBJECT_NOT_FOUND,
"Assertion [" + firstAssertion.getID() + "] is missing a subject"); "Assertion [" + firstAssertion.getID() + "] is missing a subject");
result = result.concat(error); result = result.concat(error);
@ -539,107 +554,150 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
} }
} }
private Saml2ResponseValidatorResult validateResponse(Saml2AuthenticationToken token, Response response) { private Converter<ResponseToken, Saml2ResponseValidatorResult> createDefaultResponseSignatureValidator() {
return (responseToken) -> {
Collection<Saml2Error> errors = new ArrayList<>(); Response response = responseToken.getResponse();
String issuer = response.getIssuer().getValue(); Saml2AuthenticationToken token = responseToken.getToken();
if (response.isSigned()) { Collection<Saml2Error> errors = new ArrayList<>();
SAMLSignatureProfileValidator profileValidator = new SAMLSignatureProfileValidator(); String issuer = response.getIssuer().getValue();
try { if (response.isSigned()) {
profileValidator.validate(response.getSignature()); SAMLSignatureProfileValidator profileValidator = new SAMLSignatureProfileValidator();
} try {
catch (Exception ex) { profileValidator.validate(response.getSignature());
errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, }
"Invalid signature for SAML Response [" + response.getID() + "]: ")); catch (Exception ex) {
}
try {
CriteriaSet criteriaSet = new CriteriaSet();
criteriaSet.add(new EvaluableEntityIDCredentialCriterion(new EntityIdCriterion(issuer)));
criteriaSet.add(
new EvaluableProtocolRoleDescriptorCriterion(new ProtocolCriterion(SAMLConstants.SAML20P_NS)));
criteriaSet.add(new EvaluableUsageCredentialCriterion(new UsageCriterion(UsageType.SIGNING)));
if (!this.signatureTrustEngineConverter.convert(token).validate(response.getSignature(), criteriaSet)) {
errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE,
"Invalid signature for SAML Response [" + response.getID() + "]")); "Invalid signature for SAML Response [" + response.getID() + "]: "));
}
try {
CriteriaSet criteriaSet = new CriteriaSet();
criteriaSet.add(new EvaluableEntityIDCredentialCriterion(new EntityIdCriterion(issuer)));
criteriaSet.add(new EvaluableProtocolRoleDescriptorCriterion(
new ProtocolCriterion(SAMLConstants.SAML20P_NS)));
criteriaSet.add(new EvaluableUsageCredentialCriterion(new UsageCriterion(UsageType.SIGNING)));
if (!this.signatureTrustEngineConverter.convert(token).validate(response.getSignature(),
criteriaSet)) {
errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE,
"Invalid signature for SAML Response [" + response.getID() + "]"));
}
}
catch (Exception ex) {
errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE,
"Invalid signature for SAML Response [" + response.getID() + "]: "));
} }
} }
return Saml2ResponseValidatorResult.failure(errors);
};
}
private Consumer<ResponseToken> createDefaultResponseElementsDecrypter() {
return (responseToken) -> {
Decrypter decrypter = this.decrypterConverter.convert(responseToken.getToken());
Response response = responseToken.getResponse();
for (EncryptedAssertion encryptedAssertion : responseToken.getResponse().getEncryptedAssertions()) {
try {
Assertion assertion = decrypter.decrypt(encryptedAssertion);
response.getAssertions().add(assertion);
}
catch (Exception ex) {
throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
}
}
};
}
private Converter<ResponseToken, Saml2ResponseValidatorResult> createDefaultResponseValidator() {
return (responseToken) -> {
Response response = responseToken.getResponse();
Saml2AuthenticationToken token = responseToken.getToken();
Saml2ResponseValidatorResult result = Saml2ResponseValidatorResult.success();
String issuer = response.getIssuer().getValue();
String destination = response.getDestination();
String location = token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation();
if (StringUtils.hasText(destination) && !destination.equals(location)) {
String message = "Invalid destination [" + destination + "] for SAML response [" + response.getID()
+ "]";
result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_DESTINATION, message));
}
String assertingPartyEntityId = token.getRelyingPartyRegistration().getAssertingPartyDetails()
.getEntityId();
if (!StringUtils.hasText(issuer) || !issuer.equals(assertingPartyEntityId)) {
String message = String.format("Invalid issuer [%s] for SAML response [%s]", issuer, response.getID());
result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_ISSUER, message));
}
if (response.getAssertions().isEmpty()) {
throw createAuthenticationException(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA,
"No assertions found in response.", null);
}
return result;
};
}
private Converter<AssertionToken, Saml2ResponseValidatorResult> createDefaultAssertionSignatureValidator() {
return createAssertionValidator(Saml2ErrorCodes.INVALID_SIGNATURE, (assertionToken) -> {
SignatureTrustEngine engine = this.signatureTrustEngineConverter.convert(assertionToken.token);
return SAML20AssertionValidators.createSignatureValidator(engine);
}, (assertionToken) -> new ValidationContext(
Collections.singletonMap(SAML2AssertionValidationParameters.SIGNATURE_REQUIRED, false)));
}
private Consumer<AssertionToken> createDefaultAssertionElementsDecrypter() {
return (assertionToken) -> {
Decrypter decrypter = this.decrypterConverter.convert(assertionToken.getToken());
Assertion assertion = assertionToken.getAssertion();
if (assertion.getSubject() == null) {
return;
}
if (assertion.getSubject().getEncryptedID() == null) {
return;
}
try {
assertion.getSubject().setNameID((NameID) decrypter.decrypt(assertion.getSubject().getEncryptedID()));
}
catch (Exception ex) { catch (Exception ex) {
errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
"Invalid signature for SAML Response [" + response.getID() + "]: "));
} }
} };
String destination = response.getDestination();
String location = token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation();
if (StringUtils.hasText(destination) && !destination.equals(location)) {
String message = "Invalid destination [" + destination + "] for SAML response [" + response.getID() + "]";
errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_DESTINATION, message));
}
String assertingPartyEntityId = token.getRelyingPartyRegistration().getAssertingPartyDetails().getEntityId();
if (!StringUtils.hasText(issuer) || !issuer.equals(assertingPartyEntityId)) {
String message = String.format("Invalid issuer [%s] for SAML response [%s]", issuer, response.getID());
errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_ISSUER, message));
}
return Saml2ResponseValidatorResult.failure(errors);
} }
private List<Assertion> decryptAssertions(ResponseToken response) { private Converter<AssertionToken, Saml2ResponseValidatorResult> createCompatibleAssertionValidator() {
this.assertionDecrypter.accept(response); return createAssertionValidator(Saml2ErrorCodes.INVALID_ASSERTION,
return response.getResponse().getAssertions(); (assertionToken) -> SAML20AssertionValidators.attributeValidator,
(assertionToken) -> createValidationContext(assertionToken,
(params) -> params.put(SAML2AssertionValidationParameters.CLOCK_SKEW,
this.responseTimeValidationSkew.toMillis())));
} }
private Saml2ResponseValidatorResult validateAssertions(Saml2AuthenticationToken token, Response response) { private Converter<ResponseToken, Saml2Authentication> createCompatibleResponseAuthenticationConverter() {
List<Assertion> assertions = response.getAssertions(); return (responseToken) -> {
if (assertions.isEmpty()) { Response response = responseToken.response;
throw createAuthenticationException(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA, Saml2AuthenticationToken token = responseToken.token;
"No assertions found in response.", null); Assertion assertion = CollectionUtils.firstElement(response.getAssertions());
} String username = assertion.getSubject().getNameID().getValue();
Map<String, List<Object>> attributes = getAssertionAttributes(assertion);
Saml2ResponseValidatorResult result = Saml2ResponseValidatorResult.success(); return new Saml2Authentication(new DefaultSaml2AuthenticatedPrincipal(username, attributes),
if (logger.isDebugEnabled()) { token.getSaml2Response(),
logger.debug("Validating " + assertions.size() + " assertions"); this.authoritiesMapper.mapAuthorities(getAssertionAuthorities(assertion)));
} };
for (Assertion assertion : assertions) {
if (logger.isTraceEnabled()) {
logger.trace("Validating assertion " + assertion.getID());
}
AssertionToken assertionToken = new AssertionToken(assertion, token);
result = result.concat(this.assertionSignatureValidator.convert(assertionToken))
.concat(this.assertionValidator.convert(assertionToken));
}
return result;
} }
private void addValidationException(Map<String, Saml2AuthenticationException> exceptions, String code, private Collection<? extends GrantedAuthority> getAssertionAuthorities(Assertion assertion) {
String message, Exception cause) { return this.authoritiesExtractor.convert(assertion);
exceptions.put(code, createAuthenticationException(code, message, cause));
} }
private boolean isSigned(boolean responseSigned, List<Assertion> assertions) { private boolean hasName(Assertion assertion) {
if (responseSigned) { if (assertion == null) {
return true; return false;
} }
for (Assertion assertion : assertions) {
if (!assertion.isSigned()) {
return false;
}
}
return true;
}
private NameID decryptPrincipal(ResponseToken responseToken) {
Assertion assertion = CollectionUtils.firstElement(responseToken.getResponse().getAssertions());
if (assertion.getSubject() == null) { if (assertion.getSubject() == null) {
return null; return false;
} }
if (assertion.getSubject().getEncryptedID() == null) { if (assertion.getSubject().getNameID() == null) {
return assertion.getSubject().getNameID(); return false;
} }
this.principalDecrypter.accept(responseToken); return assertion.getSubject().getNameID().getValue() != null;
return assertion.getSubject().getNameID();
} }
private static Map<String, List<Object>> getAssertionAttributes(Assertion assertion) { private static Map<String, List<Object>> getAssertionAttributes(Assertion assertion) {
@ -688,8 +746,8 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
return new Saml2AuthenticationException(new Saml2Error(code, message), cause); return new Saml2AuthenticationException(new Saml2Error(code, message), cause);
} }
private static Converter<AssertionToken, Saml2ResponseValidatorResult> createDefaultAssertionValidator( private static Converter<AssertionToken, Saml2ResponseValidatorResult> createAssertionValidator(String errorCode,
String errorCode, Converter<AssertionToken, SAML20AssertionValidator> validatorConverter, Converter<AssertionToken, SAML20AssertionValidator> validatorConverter,
Converter<AssertionToken, ValidationContext> contextConverter) { Converter<AssertionToken, ValidationContext> contextConverter) {
return (assertionToken) -> { return (assertionToken) -> {

View File

@ -47,6 +47,10 @@ import org.opensaml.saml.saml2.core.EncryptedID;
import org.opensaml.saml.saml2.core.NameID; import org.opensaml.saml.saml2.core.NameID;
import org.opensaml.saml.saml2.core.OneTimeUse; import org.opensaml.saml.saml2.core.OneTimeUse;
import org.opensaml.saml.saml2.core.Response; import org.opensaml.saml.saml2.core.Response;
import org.opensaml.saml.saml2.core.impl.EncryptedAssertionBuilder;
import org.opensaml.saml.saml2.core.impl.EncryptedIDBuilder;
import org.opensaml.saml.saml2.core.impl.NameIDBuilder;
import org.opensaml.xmlsec.encryption.impl.EncryptedDataBuilder;
import org.w3c.dom.Element; import org.w3c.dom.Element;
import org.springframework.core.convert.converter.Converter; import org.springframework.core.convert.converter.Converter;
@ -241,6 +245,8 @@ public class OpenSamlAuthenticationProviderTests {
EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(TestOpenSamlObjects.assertion(), EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(TestOpenSamlObjects.assertion(),
TestSaml2X509Credentials.assertingPartyEncryptingCredential()); TestSaml2X509Credentials.assertingPartyEncryptingCredential());
response.getEncryptedAssertions().add(encryptedAssertion); response.getEncryptedAssertions().add(encryptedAssertion);
TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(),
RELYING_PARTY_ENTITY_ID);
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyDecryptingCredential()); Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyDecryptingCredential());
assertThatExceptionOfType(Saml2AuthenticationException.class) assertThatExceptionOfType(Saml2AuthenticationException.class)
.isThrownBy(() -> this.provider.authenticate(token)) .isThrownBy(() -> this.provider.authenticate(token))
@ -255,6 +261,8 @@ public class OpenSamlAuthenticationProviderTests {
EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(assertion, EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(assertion,
TestSaml2X509Credentials.assertingPartyEncryptingCredential()); TestSaml2X509Credentials.assertingPartyEncryptingCredential());
response.getEncryptedAssertions().add(encryptedAssertion); response.getEncryptedAssertions().add(encryptedAssertion);
TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(),
RELYING_PARTY_ENTITY_ID);
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential(), Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential(),
TestSaml2X509Credentials.relyingPartyDecryptingCredential()); TestSaml2X509Credentials.relyingPartyDecryptingCredential());
this.provider.authenticate(token); this.provider.authenticate(token);
@ -296,6 +304,8 @@ public class OpenSamlAuthenticationProviderTests {
EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(TestOpenSamlObjects.assertion(), EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(TestOpenSamlObjects.assertion(),
TestSaml2X509Credentials.assertingPartyEncryptingCredential()); TestSaml2X509Credentials.assertingPartyEncryptingCredential());
response.getEncryptedAssertions().add(encryptedAssertion); response.getEncryptedAssertions().add(encryptedAssertion);
TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(),
RELYING_PARTY_ENTITY_ID);
Saml2AuthenticationToken token = token(serialize(response), Saml2AuthenticationToken token = token(serialize(response),
TestSaml2X509Credentials.relyingPartyVerifyingCredential()); TestSaml2X509Credentials.relyingPartyVerifyingCredential());
assertThatExceptionOfType(Saml2AuthenticationException.class) assertThatExceptionOfType(Saml2AuthenticationException.class)
@ -309,6 +319,8 @@ public class OpenSamlAuthenticationProviderTests {
EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(TestOpenSamlObjects.assertion(), EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(TestOpenSamlObjects.assertion(),
TestSaml2X509Credentials.assertingPartyEncryptingCredential()); TestSaml2X509Credentials.assertingPartyEncryptingCredential());
response.getEncryptedAssertions().add(encryptedAssertion); response.getEncryptedAssertions().add(encryptedAssertion);
TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(),
RELYING_PARTY_ENTITY_ID);
Saml2AuthenticationToken token = token(serialize(response), Saml2AuthenticationToken token = token(serialize(response),
TestSaml2X509Credentials.assertingPartyPrivateCredential()); TestSaml2X509Credentials.assertingPartyPrivateCredential());
assertThatExceptionOfType(Saml2AuthenticationException.class) assertThatExceptionOfType(Saml2AuthenticationException.class)
@ -324,6 +336,8 @@ public class OpenSamlAuthenticationProviderTests {
EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(assertion, EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(assertion,
TestSaml2X509Credentials.assertingPartyEncryptingCredential()); TestSaml2X509Credentials.assertingPartyEncryptingCredential());
response.getEncryptedAssertions().add(encryptedAssertion); response.getEncryptedAssertions().add(encryptedAssertion);
TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(),
RELYING_PARTY_ENTITY_ID);
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential(), Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential(),
TestSaml2X509Credentials.relyingPartyDecryptingCredential()); TestSaml2X509Credentials.relyingPartyDecryptingCredential());
Saml2Authentication authentication = (Saml2Authentication) this.provider.authenticate(token); Saml2Authentication authentication = (Saml2Authentication) this.provider.authenticate(token);
@ -473,54 +487,48 @@ public class OpenSamlAuthenticationProviderTests {
} }
@Test @Test
public void setAssertionDecrypterWhenNullThenIllegalArgument() { public void setResponseElementsDecrypterWhenNullThenIllegalArgument() {
assertThatIllegalArgumentException().isThrownBy(() -> this.provider.setAssertionDecrypter(null)); assertThatIllegalArgumentException().isThrownBy(() -> this.provider.setResponseElementsDecrypter(null));
} }
@Test @Test
public void setPrincipalDecrypterWhenNullThenIllegalArgument() { public void setAssertionElementsDecrypterWhenNullThenIllegalArgument() {
assertThatIllegalArgumentException().isThrownBy(() -> this.provider.setPrincipalDecrypter(null)); assertThatIllegalArgumentException().isThrownBy(() -> this.provider.setAssertionElementsDecrypter(null));
} }
@Test @Test
public void setAssertionDecrypterThenChangesAssertion() { public void authenticateWhenCustomResponseElementsDecrypterThenDecryptsResponse() {
Response response = TestOpenSamlObjects.response(); Response response = TestOpenSamlObjects.response();
Assertion assertion = TestOpenSamlObjects.assertion(); Assertion assertion = TestOpenSamlObjects.assertion();
assertion.getSubject().getSubjectConfirmations() TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(),
.forEach((sc) -> sc.getSubjectConfirmationData().setAddress("10.10.10.10")); RELYING_PARTY_ENTITY_ID);
response.getEncryptedAssertions().add(new EncryptedAssertionBuilder().buildObject());
TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(),
RELYING_PARTY_ENTITY_ID);
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
this.provider.setResponseElementsDecrypter((tuple) -> tuple.getResponse().getAssertions().add(assertion));
Authentication authentication = this.provider.authenticate(token);
assertThat(authentication.getName()).isEqualTo("test@saml.user");
}
@Test
public void authenticateWhenCustomAssertionElementsDecrypterThenDecryptsAssertion() {
Response response = TestOpenSamlObjects.response();
Assertion assertion = TestOpenSamlObjects.assertion();
EncryptedID id = new EncryptedIDBuilder().buildObject();
id.setEncryptedData(new EncryptedDataBuilder().buildObject());
assertion.getSubject().setEncryptedID(id);
TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(), TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(),
RELYING_PARTY_ENTITY_ID); RELYING_PARTY_ENTITY_ID);
response.getAssertions().add(assertion); response.getAssertions().add(assertion);
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential()); Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
this.provider.setAssertionDecrypter(mockAssertionAndPrincipalDecrypter()); this.provider.setAssertionElementsDecrypter((tuple) -> {
assertThatExceptionOfType(Saml2AuthenticationException.class) NameID name = new NameIDBuilder().buildObject();
.isThrownBy(() -> this.provider.authenticate(token)) name.setValue("decrypted name");
.satisfies(errorOf(Saml2ErrorCodes.INVALID_SIGNATURE)); tuple.getAssertion().getSubject().setNameID(name);
assertThat(response.getAssertions().get(0).equals(TestOpenSamlObjects.assertion("1", "2", "3", "4"))); });
} Authentication authentication = this.provider.authenticate(token);
assertThat(authentication.getName()).isEqualTo("decrypted name");
@Test
public void setPrincipalDecrypterThenChangesAssertion() {
Response response = TestOpenSamlObjects.response();
Assertion assertion = TestOpenSamlObjects.assertion();
assertion.getSubject().getSubjectConfirmations()
.forEach((sc) -> sc.getSubjectConfirmationData().setAddress("10.10.10.10"));
TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(),
RELYING_PARTY_ENTITY_ID);
response.getAssertions().add(assertion);
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
this.provider.setPrincipalDecrypter(mockAssertionAndPrincipalDecrypter());
this.provider.authenticate(token);
assertThat(response.getAssertions().get(0).equals(TestOpenSamlObjects.assertion("1", "2", "3", "4")));
}
private Consumer<ResponseToken> mockAssertionAndPrincipalDecrypter() {
return (responseToken) -> {
responseToken.getResponse().getAssertions().clear();
responseToken.getResponse().getAssertions()
.add(TestOpenSamlObjects.signed(TestOpenSamlObjects.assertion("1", "2", "3", "4"),
TestSaml2X509Credentials.assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID));
};
} }
private <T extends XMLObject> T build(QName qName) { private <T extends XMLObject> T build(QName qName) {