diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java index 9741b5644f..8c5c153a95 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java @@ -69,6 +69,7 @@ import org.opensaml.saml.saml2.core.Attribute; import org.opensaml.saml.saml2.core.AttributeStatement; import org.opensaml.saml.saml2.core.Condition; import org.opensaml.saml.saml2.core.EncryptedAssertion; +import org.opensaml.saml.saml2.core.EncryptedAttribute; import org.opensaml.saml.saml2.core.NameID; import org.opensaml.saml.saml2.core.OneTimeUse; import org.opensaml.saml.saml2.core.Response; @@ -647,6 +648,17 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi return (assertionToken) -> { Decrypter decrypter = this.decrypterConverter.convert(assertionToken.getToken()); Assertion assertion = assertionToken.getAssertion(); + for (AttributeStatement statement : assertion.getAttributeStatements()) { + for (EncryptedAttribute encryptedAttribute : statement.getEncryptedAttributes()) { + try { + Attribute attribute = decrypter.decrypt(encryptedAttribute); + statement.getAttributes().add(attribute); + } + catch (Exception ex) { + throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex); + } + } + } if (assertion.getSubject() == null) { return; } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java index 0b63ea33e7..2ff963fe81 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java @@ -43,6 +43,7 @@ import org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters; import org.opensaml.saml.saml2.core.Assertion; import org.opensaml.saml.saml2.core.AttributeStatement; import org.opensaml.saml.saml2.core.EncryptedAssertion; +import org.opensaml.saml.saml2.core.EncryptedAttribute; import org.opensaml.saml.saml2.core.EncryptedID; import org.opensaml.saml.saml2.core.NameID; import org.opensaml.saml.saml2.core.OneTimeUse; @@ -298,6 +299,25 @@ public class OpenSamlAuthenticationProviderTests { this.provider.authenticate(token); } + @Test + public void authenticateWhenEncryptedAttributeThenDecrypts() { + Response response = TestOpenSamlObjects.response(); + Assertion assertion = TestOpenSamlObjects.assertion(); + EncryptedAttribute attribute = TestOpenSamlObjects.encrypted("name", "value", + TestSaml2X509Credentials.assertingPartyEncryptingCredential()); + AttributeStatement statement = build(AttributeStatement.DEFAULT_ELEMENT_NAME); + statement.getEncryptedAttributes().add(attribute); + assertion.getAttributeStatements().add(statement); + response.getAssertions().add(assertion); + TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(), + RELYING_PARTY_ENTITY_ID); + Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential(), + TestSaml2X509Credentials.relyingPartyDecryptingCredential()); + Saml2Authentication authentication = (Saml2Authentication) this.provider.authenticate(token); + Saml2AuthenticatedPrincipal principal = (Saml2AuthenticatedPrincipal) authentication.getPrincipal(); + assertThat(principal.getAttribute("name")).containsExactly("value"); + } + @Test public void authenticateWhenDecryptionKeysAreMissingThenThrowAuthenticationException() throws Exception { Response response = TestOpenSamlObjects.response(); diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java index c607d12527..96931e5cf0 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java @@ -59,6 +59,7 @@ import org.opensaml.saml.saml2.core.AttributeValue; import org.opensaml.saml.saml2.core.AuthnRequest; import org.opensaml.saml.saml2.core.Conditions; import org.opensaml.saml.saml2.core.EncryptedAssertion; +import org.opensaml.saml.saml2.core.EncryptedAttribute; import org.opensaml.saml.saml2.core.EncryptedID; import org.opensaml.saml.saml2.core.Issuer; import org.opensaml.saml.saml2.core.NameID; @@ -301,6 +302,18 @@ public final class TestOpenSamlObjects { } } + static EncryptedAttribute encrypted(String name, String value, Saml2X509Credential credential) { + Attribute attribute = attribute(name, value); + X509Certificate certificate = credential.getCertificate(); + Encrypter encrypter = getEncrypter(certificate); + try { + return encrypter.encrypt(attribute); + } + catch (EncryptionException ex) { + throw new Saml2Exception("Unable to encrypt nameID.", ex); + } + } + private static Encrypter getEncrypter(X509Certificate certificate) { String dataAlgorithm = XMLCipherParameters.AES_256; String keyAlgorithm = XMLCipherParameters.RSA_1_5; @@ -318,6 +331,15 @@ public final class TestOpenSamlObjects { return encrypter; } + static Attribute attribute(String name, String value) { + Attribute attribute = build(Attribute.DEFAULT_ELEMENT_NAME); + attribute.setName(name); + XSString xsValue = new XSStringBuilder().buildObject(AttributeValue.DEFAULT_ELEMENT_NAME, XSString.TYPE_NAME); + xsValue.setValue(value); + attribute.getAttributeValues().add(xsValue); + return attribute; + } + static List attributeStatements() { List attributeStatements = new ArrayList<>(); AttributeStatementBuilder attributeStatementBuilder = new AttributeStatementBuilder();