Add Configurable SAML Response Decryption

Closes gh-9044
This commit is contained in:
ryan.cassar 2020-09-28 09:39:42 +02:00 committed by Josh Cummings
parent 1436ce493e
commit 535ae3e27d
No known key found for this signature in database
GPG Key ID: 49EF60DD7FF83443
2 changed files with 137 additions and 29 deletions

View File

@ -157,6 +157,7 @@ import org.springframework.util.StringUtils;
* asserting party, IDP, verification certificates. * asserting party, IDP, verification certificates.
* </p> * </p>
* *
* @author Ryan Cassar
* @since 5.2 * @since 5.2
* @see <a href= * @see <a href=
* "https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf#page=38">SAML 2 * "https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf#page=38">SAML 2
@ -211,6 +212,32 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
private Converter<Saml2AuthenticationToken, Decrypter> decrypterConverter = new DecrypterConverter(); private Converter<Saml2AuthenticationToken, Decrypter> decrypterConverter = new DecrypterConverter();
private Consumer<ResponseToken> assertionDecrypter = (responseToken) -> {
List<Assertion> assertions = new ArrayList<>();
for (EncryptedAssertion encryptedAssertion : responseToken.getResponse().getEncryptedAssertions()) {
try {
Decrypter decrypter = this.decrypterConverter.convert(responseToken.getToken());
Assertion assertion = decrypter.decrypt(encryptedAssertion);
assertions.add(assertion);
}
catch (DecryptionException ex) {
throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
}
}
responseToken.getResponse().getAssertions().addAll(assertions);
};
private Consumer<ResponseToken> principalDecrypter = (responseToken) -> {
try {
Decrypter decrypter = this.decrypterConverter.convert(responseToken.getToken());
Assertion assertion = CollectionUtils.firstElement(responseToken.getResponse().getAssertions());
assertion.getSubject().setNameID((NameID) decrypter.decrypt(assertion.getSubject().getEncryptedID()));
}
catch (DecryptionException ex) {
throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
}
};
/** /**
* Creates an {@link OpenSamlAuthenticationProvider} * Creates an {@link OpenSamlAuthenticationProvider}
*/ */
@ -332,6 +359,52 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
this.responseTimeValidationSkew = responseTimeValidationSkew; this.responseTimeValidationSkew = responseTimeValidationSkew;
} }
/**
* Sets the assertion response custom decrypter.
*
* You can use this method like so:
*
* <pre>
* YourDecrypter decrypter = // ... your custom decrypter
*
* OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
* provider.setAssertionDecrypter((responseToken) -> {
* Response response = responseToken.getResponse();
* EncryptedAssertion encrypted = response.getEncryptedAssertions().get(0);
* Assertion assertion = decrypter.decrypt(encrypted);
* response.getAssertions().add(assertion);
* });
* </pre>
* @param assertionDecrypter response token consumer
*/
public void setAssertionDecrypter(Consumer<ResponseToken> assertionDecrypter) {
Assert.notNull(assertionDecrypter, "Consumer<ResponseToken> required");
this.assertionDecrypter = assertionDecrypter;
}
/**
* Sets the principal custom decrypter.
*
* You can use this method like so:
*
* <pre>
* YourDecrypter decrypter = // ... your custom decrypter
*
* OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
* provider.setAssertionDecrypter((responseToken) -> {
* Assertion assertion = CollectionUtils.firstElement(responseToken.getResponse().getAssertions());
* EncryptedID encrypted = assertion.getSubject().getEncryptedID();
* NameID name = decrypter.decrypt(encrypted);
* assertion.getSubject().setNameID(name)
* });
* </pre>
* @param principalDecrypter response token consumer
*/
public void setPrincipalDecrypter(Consumer<ResponseToken> principalDecrypter) {
Assert.notNull(principalDecrypter, "Consumer<ResponseToken> required");
this.principalDecrypter = principalDecrypter;
}
/** /**
* Construct a default strategy for validating each SAML 2.0 Assertion and associated * Construct a default strategy for validating each SAML 2.0 Assertion and associated
* {@link Authentication} token * {@link Authentication} token
@ -429,8 +502,8 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
boolean responseSigned = response.isSigned(); boolean responseSigned = response.isSigned();
Saml2ResponseValidatorResult result = validateResponse(token, response); Saml2ResponseValidatorResult result = validateResponse(token, response);
Decrypter decrypter = this.decrypterConverter.convert(token); ResponseToken responseToken = new ResponseToken(response, token);
List<Assertion> assertions = decryptAssertions(decrypter, response); List<Assertion> assertions = decryptAssertions(responseToken);
if (!isSigned(responseSigned, assertions)) { if (!isSigned(responseSigned, assertions)) {
String description = "Either the response or one of the assertions is unsigned. " String description = "Either the response or one of the assertions is unsigned. "
+ "Please either sign the response or all of the assertions."; + "Please either sign the response or all of the assertions.";
@ -439,7 +512,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
result = result.concat(validateAssertions(token, response)); result = result.concat(validateAssertions(token, response));
Assertion firstAssertion = CollectionUtils.firstElement(response.getAssertions()); Assertion firstAssertion = CollectionUtils.firstElement(response.getAssertions());
NameID nameId = decryptPrincipal(decrypter, firstAssertion); NameID nameId = decryptPrincipal(responseToken);
if (nameId == null || nameId.getValue() == null) { if (nameId == null || nameId.getValue() == null) {
Saml2Error error = new Saml2Error(Saml2ErrorCodes.SUBJECT_NOT_FOUND, Saml2Error error = new Saml2Error(Saml2ErrorCodes.SUBJECT_NOT_FOUND,
"Assertion [" + firstAssertion.getID() + "] is missing a subject"); "Assertion [" + firstAssertion.getID() + "] is missing a subject");
@ -511,19 +584,9 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
return Saml2ResponseValidatorResult.failure(errors); return Saml2ResponseValidatorResult.failure(errors);
} }
private List<Assertion> decryptAssertions(Decrypter decrypter, Response response) { private List<Assertion> decryptAssertions(ResponseToken response) {
List<Assertion> assertions = new ArrayList<>(); this.assertionDecrypter.accept(response);
for (EncryptedAssertion encryptedAssertion : response.getEncryptedAssertions()) { return response.getResponse().getAssertions();
try {
Assertion assertion = decrypter.decrypt(encryptedAssertion);
assertions.add(assertion);
}
catch (DecryptionException ex) {
throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
}
}
response.getAssertions().addAll(assertions);
return response.getAssertions();
} }
private Saml2ResponseValidatorResult validateAssertions(Saml2AuthenticationToken token, Response response) { private Saml2ResponseValidatorResult validateAssertions(Saml2AuthenticationToken token, Response response) {
@ -567,21 +630,16 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
return true; return true;
} }
private NameID decryptPrincipal(Decrypter decrypter, Assertion assertion) { private NameID decryptPrincipal(ResponseToken responseToken) {
Assertion assertion = CollectionUtils.firstElement(responseToken.getResponse().getAssertions());
if (assertion.getSubject() == null) { if (assertion.getSubject() == null) {
return null; return null;
} }
if (assertion.getSubject().getEncryptedID() == null) { if (assertion.getSubject().getEncryptedID() == null) {
return assertion.getSubject().getNameID(); return assertion.getSubject().getNameID();
} }
try { this.principalDecrypter.accept(responseToken);
NameID nameId = (NameID) decrypter.decrypt(assertion.getSubject().getEncryptedID()); return assertion.getSubject().getNameID();
assertion.getSubject().setNameID(nameId);
return nameId;
}
catch (DecryptionException ex) {
throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
}
} }
private static Map<String, List<Object>> getAssertionAttributes(Assertion assertion) { private static Map<String, List<Object>> getAssertionAttributes(Assertion assertion) {

View File

@ -56,6 +56,7 @@ import org.springframework.security.saml2.core.Saml2Error;
import org.springframework.security.saml2.core.Saml2ResponseValidatorResult; import org.springframework.security.saml2.core.Saml2ResponseValidatorResult;
import org.springframework.security.saml2.credentials.Saml2X509Credential; import org.springframework.security.saml2.credentials.Saml2X509Credential;
import org.springframework.security.saml2.credentials.TestSaml2X509Credentials; import org.springframework.security.saml2.credentials.TestSaml2X509Credentials;
import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider.ResponseToken;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
@ -446,8 +447,7 @@ public class OpenSamlAuthenticationProviderTests {
public void createDefaultResponseAuthenticationConverterWhenResponseThenConverts() { public void createDefaultResponseAuthenticationConverterWhenResponseThenConverts() {
Response response = TestOpenSamlObjects.signedResponseWithOneAssertion(); Response response = TestOpenSamlObjects.signedResponseWithOneAssertion();
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential()); Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
OpenSamlAuthenticationProvider.ResponseToken responseToken = new OpenSamlAuthenticationProvider.ResponseToken( ResponseToken responseToken = new ResponseToken(response, token);
response, token);
Saml2Authentication authentication = OpenSamlAuthenticationProvider Saml2Authentication authentication = OpenSamlAuthenticationProvider
.createDefaultResponseAuthenticationConverter().convert(responseToken); .createDefaultResponseAuthenticationConverter().convert(responseToken);
assertThat(authentication.getName()).isEqualTo("test@saml.user"); assertThat(authentication.getName()).isEqualTo("test@saml.user");
@ -455,8 +455,7 @@ public class OpenSamlAuthenticationProviderTests {
@Test @Test
public void authenticateWhenResponseAuthenticationConverterConfiguredThenUses() { public void authenticateWhenResponseAuthenticationConverterConfiguredThenUses() {
Converter<OpenSamlAuthenticationProvider.ResponseToken, Saml2Authentication> authenticationConverter = mock( Converter<ResponseToken, Saml2Authentication> authenticationConverter = mock(Converter.class);
Converter.class);
OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider(); OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
provider.setResponseAuthenticationConverter(authenticationConverter); provider.setResponseAuthenticationConverter(authenticationConverter);
Response response = TestOpenSamlObjects.signedResponseWithOneAssertion(); Response response = TestOpenSamlObjects.signedResponseWithOneAssertion();
@ -473,6 +472,57 @@ public class OpenSamlAuthenticationProviderTests {
// @formatter:on // @formatter:on
} }
@Test
public void setAssertionDecrypterWhenNullThenIllegalArgument() {
assertThatIllegalArgumentException().isThrownBy(() -> this.provider.setAssertionDecrypter(null));
}
@Test
public void setPrincipalDecrypterWhenNullThenIllegalArgument() {
assertThatIllegalArgumentException().isThrownBy(() -> this.provider.setPrincipalDecrypter(null));
}
@Test
public void setAssertionDecrypterThenChangesAssertion() {
Response response = TestOpenSamlObjects.response();
Assertion assertion = TestOpenSamlObjects.assertion();
assertion.getSubject().getSubjectConfirmations()
.forEach((sc) -> sc.getSubjectConfirmationData().setAddress("10.10.10.10"));
TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(),
RELYING_PARTY_ENTITY_ID);
response.getAssertions().add(assertion);
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
this.provider.setAssertionDecrypter(mockAssertionAndPrincipalDecrypter());
assertThatExceptionOfType(Saml2AuthenticationException.class)
.isThrownBy(() -> this.provider.authenticate(token))
.satisfies(errorOf(Saml2ErrorCodes.INVALID_SIGNATURE));
assertThat(response.getAssertions().get(0).equals(TestOpenSamlObjects.assertion("1", "2", "3", "4")));
}
@Test
public void setPrincipalDecrypterThenChangesAssertion() {
Response response = TestOpenSamlObjects.response();
Assertion assertion = TestOpenSamlObjects.assertion();
assertion.getSubject().getSubjectConfirmations()
.forEach((sc) -> sc.getSubjectConfirmationData().setAddress("10.10.10.10"));
TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(),
RELYING_PARTY_ENTITY_ID);
response.getAssertions().add(assertion);
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
this.provider.setPrincipalDecrypter(mockAssertionAndPrincipalDecrypter());
this.provider.authenticate(token);
assertThat(response.getAssertions().get(0).equals(TestOpenSamlObjects.assertion("1", "2", "3", "4")));
}
private Consumer<ResponseToken> mockAssertionAndPrincipalDecrypter() {
return (responseToken) -> {
responseToken.getResponse().getAssertions().clear();
responseToken.getResponse().getAssertions()
.add(TestOpenSamlObjects.signed(TestOpenSamlObjects.assertion("1", "2", "3", "4"),
TestSaml2X509Credentials.assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID));
};
}
private <T extends XMLObject> T build(QName qName) { private <T extends XMLObject> T build(QName qName) {
return (T) XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(qName).buildObject(qName); return (T) XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(qName).buildObject(qName);
} }