diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java index cb99cd0df6..9741b5644f 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java @@ -88,7 +88,6 @@ import org.opensaml.security.criteria.UsageCriterion; import org.opensaml.security.x509.BasicX509Credential; import org.opensaml.xmlsec.config.impl.DefaultSecurityConfigurationBootstrap; import org.opensaml.xmlsec.encryption.support.ChainingEncryptedKeyResolver; -import org.opensaml.xmlsec.encryption.support.DecryptionException; import org.opensaml.xmlsec.encryption.support.EncryptedKeyResolver; import org.opensaml.xmlsec.encryption.support.InlineEncryptedKeyResolver; import org.opensaml.xmlsec.encryption.support.SimpleRetrievalMethodEncryptedKeyResolver; @@ -185,59 +184,24 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi private Duration responseTimeValidationSkew = Duration.ofMinutes(5); - private Converter responseAuthenticationConverter = ( - responseToken) -> { - Response response = responseToken.response; - Saml2AuthenticationToken token = responseToken.token; - Assertion assertion = CollectionUtils.firstElement(response.getAssertions()); - String username = assertion.getSubject().getNameID().getValue(); - Map> attributes = getAssertionAttributes(assertion); - return new Saml2Authentication(new DefaultSaml2AuthenticatedPrincipal(username, attributes), - token.getSaml2Response(), this.authoritiesMapper.mapAuthorities(getAssertionAuthorities(assertion))); - }; + private Converter responseSignatureValidator = createDefaultResponseSignatureValidator(); - private Converter assertionSignatureValidator = createDefaultAssertionValidator( - Saml2ErrorCodes.INVALID_SIGNATURE, (assertionToken) -> { - SignatureTrustEngine engine = this.signatureTrustEngineConverter.convert(assertionToken.token); - return SAML20AssertionValidators.createSignatureValidator(engine); - }, (assertionToken) -> new ValidationContext( - Collections.singletonMap(SAML2AssertionValidationParameters.SIGNATURE_REQUIRED, false))); + private Consumer responseElementsDecrypter = createDefaultResponseElementsDecrypter(); - private Converter assertionValidator = createDefaultAssertionValidator( - Saml2ErrorCodes.INVALID_ASSERTION, (assertionToken) -> SAML20AssertionValidators.attributeValidator, - (assertionToken) -> createValidationContext(assertionToken, (params) -> params - .put(SAML2AssertionValidationParameters.CLOCK_SKEW, this.responseTimeValidationSkew.toMillis()))); + private Converter responseValidator = createDefaultResponseValidator(); + + private Converter assertionSignatureValidator = createDefaultAssertionSignatureValidator(); + + private Consumer assertionElementsDecrypter = createDefaultAssertionElementsDecrypter(); + + private Converter assertionValidator = createCompatibleAssertionValidator(); + + private Converter responseAuthenticationConverter = createCompatibleResponseAuthenticationConverter(); private Converter signatureTrustEngineConverter = new SignatureTrustEngineConverter(); private Converter decrypterConverter = new DecrypterConverter(); - private Consumer assertionDecrypter = (responseToken) -> { - List assertions = new ArrayList<>(); - for (EncryptedAssertion encryptedAssertion : responseToken.getResponse().getEncryptedAssertions()) { - try { - Decrypter decrypter = this.decrypterConverter.convert(responseToken.getToken()); - Assertion assertion = decrypter.decrypt(encryptedAssertion); - assertions.add(assertion); - } - catch (DecryptionException ex) { - throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex); - } - } - responseToken.getResponse().getAssertions().addAll(assertions); - }; - - private Consumer principalDecrypter = (responseToken) -> { - try { - Decrypter decrypter = this.decrypterConverter.convert(responseToken.getToken()); - Assertion assertion = CollectionUtils.firstElement(responseToken.getResponse().getAssertions()); - assertion.getSubject().setNameID((NameID) decrypter.decrypt(assertion.getSubject().getEncryptedID())); - } - catch (DecryptionException ex) { - throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex); - } - }; - /** * Creates an {@link OpenSamlAuthenticationProvider} */ @@ -248,12 +212,60 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi this.parserPool = this.registry.getParserPool(); } + /** + * Set the {@link Consumer} strategy to use for decrypting elements of a validated + * {@link Response}. The default strategy decrypts all {@link EncryptedAssertion}s + * using OpenSAML's {@link Decrypter}, adding the results to + * {@link Response#getAssertions()}. + * + * You can use this method to configure the {@link Decrypter} instance like so: + * + *
+	 *	OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
+	 *	provider.setResponseElementsDecrypter((responseToken) -> {
+	 *	    DecrypterParameters parameters = new DecrypterParameters();
+	 *	    // ... set parameters as needed
+	 *	    Decrypter decrypter = new Decrypter(parameters);
+	 *		Response response = responseToken.getResponse();
+	 *  	EncryptedAssertion encrypted = response.getEncryptedAssertions().get(0);
+	 *  	try {
+	 *  		Assertion assertion = decrypter.decrypt(encrypted);
+	 *  		response.getAssertions().add(assertion);
+	 *  	} catch (Exception e) {
+	 *  	 	throw new Saml2AuthenticationException(...);
+	 *  	}
+	 *	});
+	 * 
+ * + * Or, in the event that you have your own custom decryption interface, the same + * pattern applies: + * + *
+	 *	OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
+	 *	Converter<EncryptedAssertion, Assertion> myService = ...
+	 *	provider.setResponseDecrypter((responseToken) -> {
+	 *	   Response response = responseToken.getResponse();
+	 *	   response.getEncryptedAssertions().stream()
+	 *	   		.map(service::decrypt).forEach(response.getAssertions()::add);
+	 *	});
+	 * 
+ * + * This is valuable when using an external service to perform the decryption. + * @param responseElementsDecrypter the {@link Consumer} for decrypting response + * elements + * @since 5.5 + */ + public void setResponseElementsDecrypter(Consumer responseElementsDecrypter) { + Assert.notNull(responseElementsDecrypter, "responseElementsDecrypter cannot be null"); + this.responseElementsDecrypter = responseElementsDecrypter; + } + /** * Set the {@link Converter} to use for validating each {@link Assertion} in the SAML * 2.0 Response. * * You can still invoke the default validator by delgating to - * {@link #createDefaultAssertionValidator}, like so: + * {@link #createAssertionValidator}, like so: * *
 	 *	OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
@@ -294,6 +306,49 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 		this.assertionValidator = assertionValidator;
 	}
 
+	/**
+	 * Set the {@link Consumer} strategy to use for decrypting elements of a validated
+	 * {@link Assertion}.
+	 *
+	 * You can use this method to configure the {@link Decrypter} used like so:
+	 *
+	 * 
+	 *	OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
+	 *	provider.setResponseDecrypter((assertionToken) -> {
+	 *	    DecrypterParameters parameters = new DecrypterParameters();
+	 *	    // ... set parameters as needed
+	 *	    Decrypter decrypter = new Decrypter(parameters);
+	 *		Assertion assertion = assertionToken.getAssertion();
+	 *  	EncryptedID encrypted = assertion.getSubject().getEncryptedID();
+	 *  	try {
+	 *  		NameID name = decrypter.decrypt(encrypted);
+	 *  		assertion.getSubject().setNameID(name);
+	 *  	} catch (Exception e) {
+	 *  	 	throw new Saml2AuthenticationException(...);
+	 *  	}
+	 *	});
+	 * 
+ * + * Or, in the event that you have your own custom interface, the same pattern applies: + * + *
+	 *	OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
+	 *	MyDecryptionService myService = ...
+	 *	provider.setResponseDecrypter((responseToken) -> {
+	 *	   	Assertion assertion = assertionToken.getAssertion();
+	 *	   	EncryptedID encrypted = assertion.getSubject().getEncryptedID();
+	 *		NameID name = myService.decrypt(encrypted);
+	 *		assertion.getSubject().setNameID(name);
+	 *	});
+	 * 
+ * @param assertionDecrypter the {@link Consumer} for decrypting assertion elements + * @since 5.5 + */ + public void setAssertionElementsDecrypter(Consumer assertionDecrypter) { + Assert.notNull(assertionDecrypter, "assertionDecrypter cannot be null"); + this.assertionElementsDecrypter = assertionDecrypter; + } + /** * Set the {@link Converter} to use for converting a validated {@link Response} into * an {@link AbstractAuthenticationToken}. @@ -359,52 +414,6 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi this.responseTimeValidationSkew = responseTimeValidationSkew; } - /** - * Sets the assertion response custom decrypter. - * - * You can use this method like so: - * - *
-	 *	YourDecrypter decrypter = // ... your custom decrypter
-	 *
-	 *	OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
-	 *	provider.setAssertionDecrypter((responseToken) -> {
-	 *		Response response = responseToken.getResponse();
-	 *  	EncryptedAssertion encrypted = response.getEncryptedAssertions().get(0);
-	 *  	Assertion assertion = decrypter.decrypt(encrypted);
-	 *  	response.getAssertions().add(assertion);
-	 *	});
-	 * 
- * @param assertionDecrypter response token consumer - */ - public void setAssertionDecrypter(Consumer assertionDecrypter) { - Assert.notNull(assertionDecrypter, "Consumer required"); - this.assertionDecrypter = assertionDecrypter; - } - - /** - * Sets the principal custom decrypter. - * - * You can use this method like so: - * - *
-	 *	YourDecrypter decrypter = // ... your custom decrypter
-	 *
-	 *	OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
-	 *	provider.setAssertionDecrypter((responseToken) -> {
-	 *		Assertion assertion = CollectionUtils.firstElement(responseToken.getResponse().getAssertions());
-	 *		EncryptedID encrypted = assertion.getSubject().getEncryptedID();
-	 *		NameID name = decrypter.decrypt(encrypted);
-	 *		assertion.getSubject().setNameID(name)
-	 *	});
-	 * 
- * @param principalDecrypter response token consumer - */ - public void setPrincipalDecrypter(Consumer principalDecrypter) { - Assert.notNull(principalDecrypter, "Consumer required"); - this.principalDecrypter = principalDecrypter; - } - /** * Construct a default strategy for validating each SAML 2.0 Assertion and associated * {@link Authentication} token @@ -413,7 +422,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi */ public static Converter createDefaultAssertionValidator() { - return createDefaultAssertionValidator(Saml2ErrorCodes.INVALID_ASSERTION, + return createAssertionValidator(Saml2ErrorCodes.INVALID_ASSERTION, (assertionToken) -> SAML20AssertionValidators.attributeValidator, (assertionToken) -> createValidationContext(assertionToken, (params) -> { })); @@ -430,7 +439,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi public static Converter createDefaultAssertionValidator( Converter contextConverter) { - return createDefaultAssertionValidator(Saml2ErrorCodes.INVALID_ASSERTION, + return createAssertionValidator(Saml2ErrorCodes.INVALID_ASSERTION, (assertionToken) -> SAML20AssertionValidators.attributeValidator, contextConverter); } @@ -480,10 +489,6 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi return authentication != null && Saml2AuthenticationToken.class.isAssignableFrom(authentication); } - private Collection getAssertionAuthorities(Assertion assertion) { - return this.authoritiesExtractor.convert(assertion); - } - private Response parse(String response) throws Saml2Exception, Saml2AuthenticationException { try { Document document = this.parserPool @@ -500,20 +505,30 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi String issuer = response.getIssuer().getValue(); logger.debug(LogMessage.format("Processing SAML response from %s", issuer)); boolean responseSigned = response.isSigned(); - Saml2ResponseValidatorResult result = validateResponse(token, response); ResponseToken responseToken = new ResponseToken(response, token); - List assertions = decryptAssertions(responseToken); - if (!isSigned(responseSigned, assertions)) { + Saml2ResponseValidatorResult result = this.responseSignatureValidator.convert(responseToken); + if (responseSigned) { + this.responseElementsDecrypter.accept(responseToken); + } + result = result.concat(this.responseValidator.convert(responseToken)); + boolean allAssertionsSigned = true; + for (Assertion assertion : response.getAssertions()) { + AssertionToken assertionToken = new AssertionToken(assertion, token); + result = result.concat(this.assertionSignatureValidator.convert(assertionToken)); + allAssertionsSigned = allAssertionsSigned && assertion.isSigned(); + if (responseSigned || assertion.isSigned()) { + this.assertionElementsDecrypter.accept(new AssertionToken(assertion, token)); + } + result = result.concat(this.assertionValidator.convert(assertionToken)); + } + if (!responseSigned && !allAssertionsSigned) { String description = "Either the response or one of the assertions is unsigned. " + "Please either sign the response or all of the assertions."; throw createAuthenticationException(Saml2ErrorCodes.INVALID_SIGNATURE, description, null); } - result = result.concat(validateAssertions(token, response)); - Assertion firstAssertion = CollectionUtils.firstElement(response.getAssertions()); - NameID nameId = decryptPrincipal(responseToken); - if (nameId == null || nameId.getValue() == null) { + if (!hasName(firstAssertion)) { Saml2Error error = new Saml2Error(Saml2ErrorCodes.SUBJECT_NOT_FOUND, "Assertion [" + firstAssertion.getID() + "] is missing a subject"); result = result.concat(error); @@ -539,107 +554,150 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi } } - private Saml2ResponseValidatorResult validateResponse(Saml2AuthenticationToken token, Response response) { - - Collection errors = new ArrayList<>(); - String issuer = response.getIssuer().getValue(); - if (response.isSigned()) { - SAMLSignatureProfileValidator profileValidator = new SAMLSignatureProfileValidator(); - try { - profileValidator.validate(response.getSignature()); - } - catch (Exception ex) { - errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, - "Invalid signature for SAML Response [" + response.getID() + "]: ")); - } - - try { - CriteriaSet criteriaSet = new CriteriaSet(); - criteriaSet.add(new EvaluableEntityIDCredentialCriterion(new EntityIdCriterion(issuer))); - criteriaSet.add( - new EvaluableProtocolRoleDescriptorCriterion(new ProtocolCriterion(SAMLConstants.SAML20P_NS))); - criteriaSet.add(new EvaluableUsageCredentialCriterion(new UsageCriterion(UsageType.SIGNING))); - if (!this.signatureTrustEngineConverter.convert(token).validate(response.getSignature(), criteriaSet)) { + private Converter createDefaultResponseSignatureValidator() { + return (responseToken) -> { + Response response = responseToken.getResponse(); + Saml2AuthenticationToken token = responseToken.getToken(); + Collection errors = new ArrayList<>(); + String issuer = response.getIssuer().getValue(); + if (response.isSigned()) { + SAMLSignatureProfileValidator profileValidator = new SAMLSignatureProfileValidator(); + try { + profileValidator.validate(response.getSignature()); + } + catch (Exception ex) { errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, - "Invalid signature for SAML Response [" + response.getID() + "]")); + "Invalid signature for SAML Response [" + response.getID() + "]: ")); + } + + try { + CriteriaSet criteriaSet = new CriteriaSet(); + criteriaSet.add(new EvaluableEntityIDCredentialCriterion(new EntityIdCriterion(issuer))); + criteriaSet.add(new EvaluableProtocolRoleDescriptorCriterion( + new ProtocolCriterion(SAMLConstants.SAML20P_NS))); + criteriaSet.add(new EvaluableUsageCredentialCriterion(new UsageCriterion(UsageType.SIGNING))); + if (!this.signatureTrustEngineConverter.convert(token).validate(response.getSignature(), + criteriaSet)) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for SAML Response [" + response.getID() + "]")); + } + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for SAML Response [" + response.getID() + "]: ")); } } + + return Saml2ResponseValidatorResult.failure(errors); + }; + } + + private Consumer createDefaultResponseElementsDecrypter() { + return (responseToken) -> { + Decrypter decrypter = this.decrypterConverter.convert(responseToken.getToken()); + Response response = responseToken.getResponse(); + for (EncryptedAssertion encryptedAssertion : responseToken.getResponse().getEncryptedAssertions()) { + try { + Assertion assertion = decrypter.decrypt(encryptedAssertion); + response.getAssertions().add(assertion); + } + catch (Exception ex) { + throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex); + } + } + }; + } + + private Converter createDefaultResponseValidator() { + return (responseToken) -> { + Response response = responseToken.getResponse(); + Saml2AuthenticationToken token = responseToken.getToken(); + Saml2ResponseValidatorResult result = Saml2ResponseValidatorResult.success(); + String issuer = response.getIssuer().getValue(); + String destination = response.getDestination(); + String location = token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation(); + if (StringUtils.hasText(destination) && !destination.equals(location)) { + String message = "Invalid destination [" + destination + "] for SAML response [" + response.getID() + + "]"; + result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_DESTINATION, message)); + } + String assertingPartyEntityId = token.getRelyingPartyRegistration().getAssertingPartyDetails() + .getEntityId(); + if (!StringUtils.hasText(issuer) || !issuer.equals(assertingPartyEntityId)) { + String message = String.format("Invalid issuer [%s] for SAML response [%s]", issuer, response.getID()); + result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_ISSUER, message)); + } + if (response.getAssertions().isEmpty()) { + throw createAuthenticationException(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA, + "No assertions found in response.", null); + } + return result; + }; + } + + private Converter createDefaultAssertionSignatureValidator() { + return createAssertionValidator(Saml2ErrorCodes.INVALID_SIGNATURE, (assertionToken) -> { + SignatureTrustEngine engine = this.signatureTrustEngineConverter.convert(assertionToken.token); + return SAML20AssertionValidators.createSignatureValidator(engine); + }, (assertionToken) -> new ValidationContext( + Collections.singletonMap(SAML2AssertionValidationParameters.SIGNATURE_REQUIRED, false))); + } + + private Consumer createDefaultAssertionElementsDecrypter() { + return (assertionToken) -> { + Decrypter decrypter = this.decrypterConverter.convert(assertionToken.getToken()); + Assertion assertion = assertionToken.getAssertion(); + if (assertion.getSubject() == null) { + return; + } + if (assertion.getSubject().getEncryptedID() == null) { + return; + } + try { + assertion.getSubject().setNameID((NameID) decrypter.decrypt(assertion.getSubject().getEncryptedID())); + } catch (Exception ex) { - errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, - "Invalid signature for SAML Response [" + response.getID() + "]: ")); + throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex); } - } - String destination = response.getDestination(); - String location = token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation(); - if (StringUtils.hasText(destination) && !destination.equals(location)) { - String message = "Invalid destination [" + destination + "] for SAML response [" + response.getID() + "]"; - errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_DESTINATION, message)); - } - String assertingPartyEntityId = token.getRelyingPartyRegistration().getAssertingPartyDetails().getEntityId(); - if (!StringUtils.hasText(issuer) || !issuer.equals(assertingPartyEntityId)) { - String message = String.format("Invalid issuer [%s] for SAML response [%s]", issuer, response.getID()); - errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_ISSUER, message)); - } - - return Saml2ResponseValidatorResult.failure(errors); + }; } - private List decryptAssertions(ResponseToken response) { - this.assertionDecrypter.accept(response); - return response.getResponse().getAssertions(); + private Converter createCompatibleAssertionValidator() { + return createAssertionValidator(Saml2ErrorCodes.INVALID_ASSERTION, + (assertionToken) -> SAML20AssertionValidators.attributeValidator, + (assertionToken) -> createValidationContext(assertionToken, + (params) -> params.put(SAML2AssertionValidationParameters.CLOCK_SKEW, + this.responseTimeValidationSkew.toMillis()))); } - private Saml2ResponseValidatorResult validateAssertions(Saml2AuthenticationToken token, Response response) { - List assertions = response.getAssertions(); - if (assertions.isEmpty()) { - throw createAuthenticationException(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA, - "No assertions found in response.", null); - } - - Saml2ResponseValidatorResult result = Saml2ResponseValidatorResult.success(); - if (logger.isDebugEnabled()) { - logger.debug("Validating " + assertions.size() + " assertions"); - } - - for (Assertion assertion : assertions) { - if (logger.isTraceEnabled()) { - logger.trace("Validating assertion " + assertion.getID()); - } - AssertionToken assertionToken = new AssertionToken(assertion, token); - result = result.concat(this.assertionSignatureValidator.convert(assertionToken)) - .concat(this.assertionValidator.convert(assertionToken)); - } - - return result; + private Converter createCompatibleResponseAuthenticationConverter() { + return (responseToken) -> { + Response response = responseToken.response; + Saml2AuthenticationToken token = responseToken.token; + Assertion assertion = CollectionUtils.firstElement(response.getAssertions()); + String username = assertion.getSubject().getNameID().getValue(); + Map> attributes = getAssertionAttributes(assertion); + return new Saml2Authentication(new DefaultSaml2AuthenticatedPrincipal(username, attributes), + token.getSaml2Response(), + this.authoritiesMapper.mapAuthorities(getAssertionAuthorities(assertion))); + }; } - private void addValidationException(Map exceptions, String code, - String message, Exception cause) { - exceptions.put(code, createAuthenticationException(code, message, cause)); + private Collection getAssertionAuthorities(Assertion assertion) { + return this.authoritiesExtractor.convert(assertion); } - private boolean isSigned(boolean responseSigned, List assertions) { - if (responseSigned) { - return true; + private boolean hasName(Assertion assertion) { + if (assertion == null) { + return false; } - for (Assertion assertion : assertions) { - if (!assertion.isSigned()) { - return false; - } - } - return true; - } - - private NameID decryptPrincipal(ResponseToken responseToken) { - Assertion assertion = CollectionUtils.firstElement(responseToken.getResponse().getAssertions()); if (assertion.getSubject() == null) { - return null; + return false; } - if (assertion.getSubject().getEncryptedID() == null) { - return assertion.getSubject().getNameID(); + if (assertion.getSubject().getNameID() == null) { + return false; } - this.principalDecrypter.accept(responseToken); - return assertion.getSubject().getNameID(); + return assertion.getSubject().getNameID().getValue() != null; } private static Map> getAssertionAttributes(Assertion assertion) { @@ -688,8 +746,8 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi return new Saml2AuthenticationException(new Saml2Error(code, message), cause); } - private static Converter createDefaultAssertionValidator( - String errorCode, Converter validatorConverter, + private static Converter createAssertionValidator(String errorCode, + Converter validatorConverter, Converter contextConverter) { return (assertionToken) -> { diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java index 8cc69c17b5..0b63ea33e7 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java @@ -47,6 +47,10 @@ import org.opensaml.saml.saml2.core.EncryptedID; import org.opensaml.saml.saml2.core.NameID; import org.opensaml.saml.saml2.core.OneTimeUse; import org.opensaml.saml.saml2.core.Response; +import org.opensaml.saml.saml2.core.impl.EncryptedAssertionBuilder; +import org.opensaml.saml.saml2.core.impl.EncryptedIDBuilder; +import org.opensaml.saml.saml2.core.impl.NameIDBuilder; +import org.opensaml.xmlsec.encryption.impl.EncryptedDataBuilder; import org.w3c.dom.Element; import org.springframework.core.convert.converter.Converter; @@ -241,6 +245,8 @@ public class OpenSamlAuthenticationProviderTests { EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(TestOpenSamlObjects.assertion(), TestSaml2X509Credentials.assertingPartyEncryptingCredential()); response.getEncryptedAssertions().add(encryptedAssertion); + TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(), + RELYING_PARTY_ENTITY_ID); Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyDecryptingCredential()); assertThatExceptionOfType(Saml2AuthenticationException.class) .isThrownBy(() -> this.provider.authenticate(token)) @@ -255,6 +261,8 @@ public class OpenSamlAuthenticationProviderTests { EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(assertion, TestSaml2X509Credentials.assertingPartyEncryptingCredential()); response.getEncryptedAssertions().add(encryptedAssertion); + TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(), + RELYING_PARTY_ENTITY_ID); Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential(), TestSaml2X509Credentials.relyingPartyDecryptingCredential()); this.provider.authenticate(token); @@ -296,6 +304,8 @@ public class OpenSamlAuthenticationProviderTests { EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(TestOpenSamlObjects.assertion(), TestSaml2X509Credentials.assertingPartyEncryptingCredential()); response.getEncryptedAssertions().add(encryptedAssertion); + TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(), + RELYING_PARTY_ENTITY_ID); Saml2AuthenticationToken token = token(serialize(response), TestSaml2X509Credentials.relyingPartyVerifyingCredential()); assertThatExceptionOfType(Saml2AuthenticationException.class) @@ -309,6 +319,8 @@ public class OpenSamlAuthenticationProviderTests { EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(TestOpenSamlObjects.assertion(), TestSaml2X509Credentials.assertingPartyEncryptingCredential()); response.getEncryptedAssertions().add(encryptedAssertion); + TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(), + RELYING_PARTY_ENTITY_ID); Saml2AuthenticationToken token = token(serialize(response), TestSaml2X509Credentials.assertingPartyPrivateCredential()); assertThatExceptionOfType(Saml2AuthenticationException.class) @@ -324,6 +336,8 @@ public class OpenSamlAuthenticationProviderTests { EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(assertion, TestSaml2X509Credentials.assertingPartyEncryptingCredential()); response.getEncryptedAssertions().add(encryptedAssertion); + TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(), + RELYING_PARTY_ENTITY_ID); Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential(), TestSaml2X509Credentials.relyingPartyDecryptingCredential()); Saml2Authentication authentication = (Saml2Authentication) this.provider.authenticate(token); @@ -473,54 +487,48 @@ public class OpenSamlAuthenticationProviderTests { } @Test - public void setAssertionDecrypterWhenNullThenIllegalArgument() { - assertThatIllegalArgumentException().isThrownBy(() -> this.provider.setAssertionDecrypter(null)); + public void setResponseElementsDecrypterWhenNullThenIllegalArgument() { + assertThatIllegalArgumentException().isThrownBy(() -> this.provider.setResponseElementsDecrypter(null)); } @Test - public void setPrincipalDecrypterWhenNullThenIllegalArgument() { - assertThatIllegalArgumentException().isThrownBy(() -> this.provider.setPrincipalDecrypter(null)); + public void setAssertionElementsDecrypterWhenNullThenIllegalArgument() { + assertThatIllegalArgumentException().isThrownBy(() -> this.provider.setAssertionElementsDecrypter(null)); } @Test - public void setAssertionDecrypterThenChangesAssertion() { + public void authenticateWhenCustomResponseElementsDecrypterThenDecryptsResponse() { Response response = TestOpenSamlObjects.response(); Assertion assertion = TestOpenSamlObjects.assertion(); - assertion.getSubject().getSubjectConfirmations() - .forEach((sc) -> sc.getSubjectConfirmationData().setAddress("10.10.10.10")); + TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(), + RELYING_PARTY_ENTITY_ID); + response.getEncryptedAssertions().add(new EncryptedAssertionBuilder().buildObject()); + TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(), + RELYING_PARTY_ENTITY_ID); + Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential()); + this.provider.setResponseElementsDecrypter((tuple) -> tuple.getResponse().getAssertions().add(assertion)); + Authentication authentication = this.provider.authenticate(token); + assertThat(authentication.getName()).isEqualTo("test@saml.user"); + } + + @Test + public void authenticateWhenCustomAssertionElementsDecrypterThenDecryptsAssertion() { + Response response = TestOpenSamlObjects.response(); + Assertion assertion = TestOpenSamlObjects.assertion(); + EncryptedID id = new EncryptedIDBuilder().buildObject(); + id.setEncryptedData(new EncryptedDataBuilder().buildObject()); + assertion.getSubject().setEncryptedID(id); TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID); response.getAssertions().add(assertion); Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential()); - this.provider.setAssertionDecrypter(mockAssertionAndPrincipalDecrypter()); - assertThatExceptionOfType(Saml2AuthenticationException.class) - .isThrownBy(() -> this.provider.authenticate(token)) - .satisfies(errorOf(Saml2ErrorCodes.INVALID_SIGNATURE)); - assertThat(response.getAssertions().get(0).equals(TestOpenSamlObjects.assertion("1", "2", "3", "4"))); - } - - @Test - public void setPrincipalDecrypterThenChangesAssertion() { - Response response = TestOpenSamlObjects.response(); - Assertion assertion = TestOpenSamlObjects.assertion(); - assertion.getSubject().getSubjectConfirmations() - .forEach((sc) -> sc.getSubjectConfirmationData().setAddress("10.10.10.10")); - TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(), - RELYING_PARTY_ENTITY_ID); - response.getAssertions().add(assertion); - Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential()); - this.provider.setPrincipalDecrypter(mockAssertionAndPrincipalDecrypter()); - this.provider.authenticate(token); - assertThat(response.getAssertions().get(0).equals(TestOpenSamlObjects.assertion("1", "2", "3", "4"))); - } - - private Consumer mockAssertionAndPrincipalDecrypter() { - return (responseToken) -> { - responseToken.getResponse().getAssertions().clear(); - responseToken.getResponse().getAssertions() - .add(TestOpenSamlObjects.signed(TestOpenSamlObjects.assertion("1", "2", "3", "4"), - TestSaml2X509Credentials.assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID)); - }; + this.provider.setAssertionElementsDecrypter((tuple) -> { + NameID name = new NameIDBuilder().buildObject(); + name.setValue("decrypted name"); + tuple.getAssertion().getSubject().setNameID(name); + }); + Authentication authentication = this.provider.authenticate(token); + assertThat(authentication.getName()).isEqualTo("decrypted name"); } private T build(QName qName) {