diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java index 26264076ab..cb99cd0df6 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java @@ -157,6 +157,7 @@ import org.springframework.util.StringUtils; * asserting party, IDP, verification certificates. *

* + * @author Ryan Cassar * @since 5.2 * @see SAML 2 @@ -211,6 +212,32 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi private Converter decrypterConverter = new DecrypterConverter(); + private Consumer assertionDecrypter = (responseToken) -> { + List assertions = new ArrayList<>(); + for (EncryptedAssertion encryptedAssertion : responseToken.getResponse().getEncryptedAssertions()) { + try { + Decrypter decrypter = this.decrypterConverter.convert(responseToken.getToken()); + Assertion assertion = decrypter.decrypt(encryptedAssertion); + assertions.add(assertion); + } + catch (DecryptionException ex) { + throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex); + } + } + responseToken.getResponse().getAssertions().addAll(assertions); + }; + + private Consumer principalDecrypter = (responseToken) -> { + try { + Decrypter decrypter = this.decrypterConverter.convert(responseToken.getToken()); + Assertion assertion = CollectionUtils.firstElement(responseToken.getResponse().getAssertions()); + assertion.getSubject().setNameID((NameID) decrypter.decrypt(assertion.getSubject().getEncryptedID())); + } + catch (DecryptionException ex) { + throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex); + } + }; + /** * Creates an {@link OpenSamlAuthenticationProvider} */ @@ -332,6 +359,52 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi this.responseTimeValidationSkew = responseTimeValidationSkew; } + /** + * Sets the assertion response custom decrypter. + * + * You can use this method like so: + * + *
+	 *	YourDecrypter decrypter = // ... your custom decrypter
+	 *
+	 *	OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
+	 *	provider.setAssertionDecrypter((responseToken) -> {
+	 *		Response response = responseToken.getResponse();
+	 *  	EncryptedAssertion encrypted = response.getEncryptedAssertions().get(0);
+	 *  	Assertion assertion = decrypter.decrypt(encrypted);
+	 *  	response.getAssertions().add(assertion);
+	 *	});
+	 * 
+ * @param assertionDecrypter response token consumer + */ + public void setAssertionDecrypter(Consumer assertionDecrypter) { + Assert.notNull(assertionDecrypter, "Consumer required"); + this.assertionDecrypter = assertionDecrypter; + } + + /** + * Sets the principal custom decrypter. + * + * You can use this method like so: + * + *
+	 *	YourDecrypter decrypter = // ... your custom decrypter
+	 *
+	 *	OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
+	 *	provider.setAssertionDecrypter((responseToken) -> {
+	 *		Assertion assertion = CollectionUtils.firstElement(responseToken.getResponse().getAssertions());
+	 *		EncryptedID encrypted = assertion.getSubject().getEncryptedID();
+	 *		NameID name = decrypter.decrypt(encrypted);
+	 *		assertion.getSubject().setNameID(name)
+	 *	});
+	 * 
+ * @param principalDecrypter response token consumer + */ + public void setPrincipalDecrypter(Consumer principalDecrypter) { + Assert.notNull(principalDecrypter, "Consumer required"); + this.principalDecrypter = principalDecrypter; + } + /** * Construct a default strategy for validating each SAML 2.0 Assertion and associated * {@link Authentication} token @@ -429,8 +502,8 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi boolean responseSigned = response.isSigned(); Saml2ResponseValidatorResult result = validateResponse(token, response); - Decrypter decrypter = this.decrypterConverter.convert(token); - List assertions = decryptAssertions(decrypter, response); + ResponseToken responseToken = new ResponseToken(response, token); + List assertions = decryptAssertions(responseToken); if (!isSigned(responseSigned, assertions)) { String description = "Either the response or one of the assertions is unsigned. " + "Please either sign the response or all of the assertions."; @@ -439,7 +512,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi result = result.concat(validateAssertions(token, response)); Assertion firstAssertion = CollectionUtils.firstElement(response.getAssertions()); - NameID nameId = decryptPrincipal(decrypter, firstAssertion); + NameID nameId = decryptPrincipal(responseToken); if (nameId == null || nameId.getValue() == null) { Saml2Error error = new Saml2Error(Saml2ErrorCodes.SUBJECT_NOT_FOUND, "Assertion [" + firstAssertion.getID() + "] is missing a subject"); @@ -511,19 +584,9 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi return Saml2ResponseValidatorResult.failure(errors); } - private List decryptAssertions(Decrypter decrypter, Response response) { - List assertions = new ArrayList<>(); - for (EncryptedAssertion encryptedAssertion : response.getEncryptedAssertions()) { - try { - Assertion assertion = decrypter.decrypt(encryptedAssertion); - assertions.add(assertion); - } - catch (DecryptionException ex) { - throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex); - } - } - response.getAssertions().addAll(assertions); - return response.getAssertions(); + private List decryptAssertions(ResponseToken response) { + this.assertionDecrypter.accept(response); + return response.getResponse().getAssertions(); } private Saml2ResponseValidatorResult validateAssertions(Saml2AuthenticationToken token, Response response) { @@ -567,21 +630,16 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi return true; } - private NameID decryptPrincipal(Decrypter decrypter, Assertion assertion) { + private NameID decryptPrincipal(ResponseToken responseToken) { + Assertion assertion = CollectionUtils.firstElement(responseToken.getResponse().getAssertions()); if (assertion.getSubject() == null) { return null; } if (assertion.getSubject().getEncryptedID() == null) { return assertion.getSubject().getNameID(); } - try { - NameID nameId = (NameID) decrypter.decrypt(assertion.getSubject().getEncryptedID()); - assertion.getSubject().setNameID(nameId); - return nameId; - } - catch (DecryptionException ex) { - throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex); - } + this.principalDecrypter.accept(responseToken); + return assertion.getSubject().getNameID(); } private static Map> getAssertionAttributes(Assertion assertion) { diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java index e0ebdb6e25..8cc69c17b5 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java @@ -56,6 +56,7 @@ import org.springframework.security.saml2.core.Saml2Error; import org.springframework.security.saml2.core.Saml2ResponseValidatorResult; import org.springframework.security.saml2.credentials.Saml2X509Credential; import org.springframework.security.saml2.credentials.TestSaml2X509Credentials; +import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider.ResponseToken; import org.springframework.util.StringUtils; import static org.assertj.core.api.Assertions.assertThat; @@ -446,8 +447,7 @@ public class OpenSamlAuthenticationProviderTests { public void createDefaultResponseAuthenticationConverterWhenResponseThenConverts() { Response response = TestOpenSamlObjects.signedResponseWithOneAssertion(); Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential()); - OpenSamlAuthenticationProvider.ResponseToken responseToken = new OpenSamlAuthenticationProvider.ResponseToken( - response, token); + ResponseToken responseToken = new ResponseToken(response, token); Saml2Authentication authentication = OpenSamlAuthenticationProvider .createDefaultResponseAuthenticationConverter().convert(responseToken); assertThat(authentication.getName()).isEqualTo("test@saml.user"); @@ -455,8 +455,7 @@ public class OpenSamlAuthenticationProviderTests { @Test public void authenticateWhenResponseAuthenticationConverterConfiguredThenUses() { - Converter authenticationConverter = mock( - Converter.class); + Converter authenticationConverter = mock(Converter.class); OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider(); provider.setResponseAuthenticationConverter(authenticationConverter); Response response = TestOpenSamlObjects.signedResponseWithOneAssertion(); @@ -473,6 +472,57 @@ public class OpenSamlAuthenticationProviderTests { // @formatter:on } + @Test + public void setAssertionDecrypterWhenNullThenIllegalArgument() { + assertThatIllegalArgumentException().isThrownBy(() -> this.provider.setAssertionDecrypter(null)); + } + + @Test + public void setPrincipalDecrypterWhenNullThenIllegalArgument() { + assertThatIllegalArgumentException().isThrownBy(() -> this.provider.setPrincipalDecrypter(null)); + } + + @Test + public void setAssertionDecrypterThenChangesAssertion() { + Response response = TestOpenSamlObjects.response(); + Assertion assertion = TestOpenSamlObjects.assertion(); + assertion.getSubject().getSubjectConfirmations() + .forEach((sc) -> sc.getSubjectConfirmationData().setAddress("10.10.10.10")); + TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(), + RELYING_PARTY_ENTITY_ID); + response.getAssertions().add(assertion); + Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential()); + this.provider.setAssertionDecrypter(mockAssertionAndPrincipalDecrypter()); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)) + .satisfies(errorOf(Saml2ErrorCodes.INVALID_SIGNATURE)); + assertThat(response.getAssertions().get(0).equals(TestOpenSamlObjects.assertion("1", "2", "3", "4"))); + } + + @Test + public void setPrincipalDecrypterThenChangesAssertion() { + Response response = TestOpenSamlObjects.response(); + Assertion assertion = TestOpenSamlObjects.assertion(); + assertion.getSubject().getSubjectConfirmations() + .forEach((sc) -> sc.getSubjectConfirmationData().setAddress("10.10.10.10")); + TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(), + RELYING_PARTY_ENTITY_ID); + response.getAssertions().add(assertion); + Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential()); + this.provider.setPrincipalDecrypter(mockAssertionAndPrincipalDecrypter()); + this.provider.authenticate(token); + assertThat(response.getAssertions().get(0).equals(TestOpenSamlObjects.assertion("1", "2", "3", "4"))); + } + + private Consumer mockAssertionAndPrincipalDecrypter() { + return (responseToken) -> { + responseToken.getResponse().getAssertions().clear(); + responseToken.getResponse().getAssertions() + .add(TestOpenSamlObjects.signed(TestOpenSamlObjects.assertion("1", "2", "3", "4"), + TestSaml2X509Credentials.assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID)); + }; + } + private T build(QName qName) { return (T) XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(qName).buildObject(qName); }