parent
1436ce493e
commit
535ae3e27d
|
@ -157,6 +157,7 @@ import org.springframework.util.StringUtils;
|
||||||
* asserting party, IDP, verification certificates.
|
* asserting party, IDP, verification certificates.
|
||||||
* </p>
|
* </p>
|
||||||
*
|
*
|
||||||
|
* @author Ryan Cassar
|
||||||
* @since 5.2
|
* @since 5.2
|
||||||
* @see <a href=
|
* @see <a href=
|
||||||
* "https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf#page=38">SAML 2
|
* "https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf#page=38">SAML 2
|
||||||
|
@ -211,6 +212,32 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
|
||||||
|
|
||||||
private Converter<Saml2AuthenticationToken, Decrypter> decrypterConverter = new DecrypterConverter();
|
private Converter<Saml2AuthenticationToken, Decrypter> decrypterConverter = new DecrypterConverter();
|
||||||
|
|
||||||
|
private Consumer<ResponseToken> assertionDecrypter = (responseToken) -> {
|
||||||
|
List<Assertion> assertions = new ArrayList<>();
|
||||||
|
for (EncryptedAssertion encryptedAssertion : responseToken.getResponse().getEncryptedAssertions()) {
|
||||||
|
try {
|
||||||
|
Decrypter decrypter = this.decrypterConverter.convert(responseToken.getToken());
|
||||||
|
Assertion assertion = decrypter.decrypt(encryptedAssertion);
|
||||||
|
assertions.add(assertion);
|
||||||
|
}
|
||||||
|
catch (DecryptionException ex) {
|
||||||
|
throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
responseToken.getResponse().getAssertions().addAll(assertions);
|
||||||
|
};
|
||||||
|
|
||||||
|
private Consumer<ResponseToken> principalDecrypter = (responseToken) -> {
|
||||||
|
try {
|
||||||
|
Decrypter decrypter = this.decrypterConverter.convert(responseToken.getToken());
|
||||||
|
Assertion assertion = CollectionUtils.firstElement(responseToken.getResponse().getAssertions());
|
||||||
|
assertion.getSubject().setNameID((NameID) decrypter.decrypt(assertion.getSubject().getEncryptedID()));
|
||||||
|
}
|
||||||
|
catch (DecryptionException ex) {
|
||||||
|
throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates an {@link OpenSamlAuthenticationProvider}
|
* Creates an {@link OpenSamlAuthenticationProvider}
|
||||||
*/
|
*/
|
||||||
|
@ -332,6 +359,52 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
|
||||||
this.responseTimeValidationSkew = responseTimeValidationSkew;
|
this.responseTimeValidationSkew = responseTimeValidationSkew;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the assertion response custom decrypter.
|
||||||
|
*
|
||||||
|
* You can use this method like so:
|
||||||
|
*
|
||||||
|
* <pre>
|
||||||
|
* YourDecrypter decrypter = // ... your custom decrypter
|
||||||
|
*
|
||||||
|
* OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
|
||||||
|
* provider.setAssertionDecrypter((responseToken) -> {
|
||||||
|
* Response response = responseToken.getResponse();
|
||||||
|
* EncryptedAssertion encrypted = response.getEncryptedAssertions().get(0);
|
||||||
|
* Assertion assertion = decrypter.decrypt(encrypted);
|
||||||
|
* response.getAssertions().add(assertion);
|
||||||
|
* });
|
||||||
|
* </pre>
|
||||||
|
* @param assertionDecrypter response token consumer
|
||||||
|
*/
|
||||||
|
public void setAssertionDecrypter(Consumer<ResponseToken> assertionDecrypter) {
|
||||||
|
Assert.notNull(assertionDecrypter, "Consumer<ResponseToken> required");
|
||||||
|
this.assertionDecrypter = assertionDecrypter;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the principal custom decrypter.
|
||||||
|
*
|
||||||
|
* You can use this method like so:
|
||||||
|
*
|
||||||
|
* <pre>
|
||||||
|
* YourDecrypter decrypter = // ... your custom decrypter
|
||||||
|
*
|
||||||
|
* OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
|
||||||
|
* provider.setAssertionDecrypter((responseToken) -> {
|
||||||
|
* Assertion assertion = CollectionUtils.firstElement(responseToken.getResponse().getAssertions());
|
||||||
|
* EncryptedID encrypted = assertion.getSubject().getEncryptedID();
|
||||||
|
* NameID name = decrypter.decrypt(encrypted);
|
||||||
|
* assertion.getSubject().setNameID(name)
|
||||||
|
* });
|
||||||
|
* </pre>
|
||||||
|
* @param principalDecrypter response token consumer
|
||||||
|
*/
|
||||||
|
public void setPrincipalDecrypter(Consumer<ResponseToken> principalDecrypter) {
|
||||||
|
Assert.notNull(principalDecrypter, "Consumer<ResponseToken> required");
|
||||||
|
this.principalDecrypter = principalDecrypter;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Construct a default strategy for validating each SAML 2.0 Assertion and associated
|
* Construct a default strategy for validating each SAML 2.0 Assertion and associated
|
||||||
* {@link Authentication} token
|
* {@link Authentication} token
|
||||||
|
@ -429,8 +502,8 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
|
||||||
boolean responseSigned = response.isSigned();
|
boolean responseSigned = response.isSigned();
|
||||||
Saml2ResponseValidatorResult result = validateResponse(token, response);
|
Saml2ResponseValidatorResult result = validateResponse(token, response);
|
||||||
|
|
||||||
Decrypter decrypter = this.decrypterConverter.convert(token);
|
ResponseToken responseToken = new ResponseToken(response, token);
|
||||||
List<Assertion> assertions = decryptAssertions(decrypter, response);
|
List<Assertion> assertions = decryptAssertions(responseToken);
|
||||||
if (!isSigned(responseSigned, assertions)) {
|
if (!isSigned(responseSigned, assertions)) {
|
||||||
String description = "Either the response or one of the assertions is unsigned. "
|
String description = "Either the response or one of the assertions is unsigned. "
|
||||||
+ "Please either sign the response or all of the assertions.";
|
+ "Please either sign the response or all of the assertions.";
|
||||||
|
@ -439,7 +512,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
|
||||||
result = result.concat(validateAssertions(token, response));
|
result = result.concat(validateAssertions(token, response));
|
||||||
|
|
||||||
Assertion firstAssertion = CollectionUtils.firstElement(response.getAssertions());
|
Assertion firstAssertion = CollectionUtils.firstElement(response.getAssertions());
|
||||||
NameID nameId = decryptPrincipal(decrypter, firstAssertion);
|
NameID nameId = decryptPrincipal(responseToken);
|
||||||
if (nameId == null || nameId.getValue() == null) {
|
if (nameId == null || nameId.getValue() == null) {
|
||||||
Saml2Error error = new Saml2Error(Saml2ErrorCodes.SUBJECT_NOT_FOUND,
|
Saml2Error error = new Saml2Error(Saml2ErrorCodes.SUBJECT_NOT_FOUND,
|
||||||
"Assertion [" + firstAssertion.getID() + "] is missing a subject");
|
"Assertion [" + firstAssertion.getID() + "] is missing a subject");
|
||||||
|
@ -511,19 +584,9 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
|
||||||
return Saml2ResponseValidatorResult.failure(errors);
|
return Saml2ResponseValidatorResult.failure(errors);
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<Assertion> decryptAssertions(Decrypter decrypter, Response response) {
|
private List<Assertion> decryptAssertions(ResponseToken response) {
|
||||||
List<Assertion> assertions = new ArrayList<>();
|
this.assertionDecrypter.accept(response);
|
||||||
for (EncryptedAssertion encryptedAssertion : response.getEncryptedAssertions()) {
|
return response.getResponse().getAssertions();
|
||||||
try {
|
|
||||||
Assertion assertion = decrypter.decrypt(encryptedAssertion);
|
|
||||||
assertions.add(assertion);
|
|
||||||
}
|
|
||||||
catch (DecryptionException ex) {
|
|
||||||
throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
response.getAssertions().addAll(assertions);
|
|
||||||
return response.getAssertions();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private Saml2ResponseValidatorResult validateAssertions(Saml2AuthenticationToken token, Response response) {
|
private Saml2ResponseValidatorResult validateAssertions(Saml2AuthenticationToken token, Response response) {
|
||||||
|
@ -567,21 +630,16 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
private NameID decryptPrincipal(Decrypter decrypter, Assertion assertion) {
|
private NameID decryptPrincipal(ResponseToken responseToken) {
|
||||||
|
Assertion assertion = CollectionUtils.firstElement(responseToken.getResponse().getAssertions());
|
||||||
if (assertion.getSubject() == null) {
|
if (assertion.getSubject() == null) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
if (assertion.getSubject().getEncryptedID() == null) {
|
if (assertion.getSubject().getEncryptedID() == null) {
|
||||||
return assertion.getSubject().getNameID();
|
return assertion.getSubject().getNameID();
|
||||||
}
|
}
|
||||||
try {
|
this.principalDecrypter.accept(responseToken);
|
||||||
NameID nameId = (NameID) decrypter.decrypt(assertion.getSubject().getEncryptedID());
|
return assertion.getSubject().getNameID();
|
||||||
assertion.getSubject().setNameID(nameId);
|
|
||||||
return nameId;
|
|
||||||
}
|
|
||||||
catch (DecryptionException ex) {
|
|
||||||
throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private static Map<String, List<Object>> getAssertionAttributes(Assertion assertion) {
|
private static Map<String, List<Object>> getAssertionAttributes(Assertion assertion) {
|
||||||
|
|
|
@ -56,6 +56,7 @@ import org.springframework.security.saml2.core.Saml2Error;
|
||||||
import org.springframework.security.saml2.core.Saml2ResponseValidatorResult;
|
import org.springframework.security.saml2.core.Saml2ResponseValidatorResult;
|
||||||
import org.springframework.security.saml2.credentials.Saml2X509Credential;
|
import org.springframework.security.saml2.credentials.Saml2X509Credential;
|
||||||
import org.springframework.security.saml2.credentials.TestSaml2X509Credentials;
|
import org.springframework.security.saml2.credentials.TestSaml2X509Credentials;
|
||||||
|
import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider.ResponseToken;
|
||||||
import org.springframework.util.StringUtils;
|
import org.springframework.util.StringUtils;
|
||||||
|
|
||||||
import static org.assertj.core.api.Assertions.assertThat;
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
@ -446,8 +447,7 @@ public class OpenSamlAuthenticationProviderTests {
|
||||||
public void createDefaultResponseAuthenticationConverterWhenResponseThenConverts() {
|
public void createDefaultResponseAuthenticationConverterWhenResponseThenConverts() {
|
||||||
Response response = TestOpenSamlObjects.signedResponseWithOneAssertion();
|
Response response = TestOpenSamlObjects.signedResponseWithOneAssertion();
|
||||||
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
|
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
|
||||||
OpenSamlAuthenticationProvider.ResponseToken responseToken = new OpenSamlAuthenticationProvider.ResponseToken(
|
ResponseToken responseToken = new ResponseToken(response, token);
|
||||||
response, token);
|
|
||||||
Saml2Authentication authentication = OpenSamlAuthenticationProvider
|
Saml2Authentication authentication = OpenSamlAuthenticationProvider
|
||||||
.createDefaultResponseAuthenticationConverter().convert(responseToken);
|
.createDefaultResponseAuthenticationConverter().convert(responseToken);
|
||||||
assertThat(authentication.getName()).isEqualTo("test@saml.user");
|
assertThat(authentication.getName()).isEqualTo("test@saml.user");
|
||||||
|
@ -455,8 +455,7 @@ public class OpenSamlAuthenticationProviderTests {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void authenticateWhenResponseAuthenticationConverterConfiguredThenUses() {
|
public void authenticateWhenResponseAuthenticationConverterConfiguredThenUses() {
|
||||||
Converter<OpenSamlAuthenticationProvider.ResponseToken, Saml2Authentication> authenticationConverter = mock(
|
Converter<ResponseToken, Saml2Authentication> authenticationConverter = mock(Converter.class);
|
||||||
Converter.class);
|
|
||||||
OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
|
OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
|
||||||
provider.setResponseAuthenticationConverter(authenticationConverter);
|
provider.setResponseAuthenticationConverter(authenticationConverter);
|
||||||
Response response = TestOpenSamlObjects.signedResponseWithOneAssertion();
|
Response response = TestOpenSamlObjects.signedResponseWithOneAssertion();
|
||||||
|
@ -473,6 +472,57 @@ public class OpenSamlAuthenticationProviderTests {
|
||||||
// @formatter:on
|
// @formatter:on
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void setAssertionDecrypterWhenNullThenIllegalArgument() {
|
||||||
|
assertThatIllegalArgumentException().isThrownBy(() -> this.provider.setAssertionDecrypter(null));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void setPrincipalDecrypterWhenNullThenIllegalArgument() {
|
||||||
|
assertThatIllegalArgumentException().isThrownBy(() -> this.provider.setPrincipalDecrypter(null));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void setAssertionDecrypterThenChangesAssertion() {
|
||||||
|
Response response = TestOpenSamlObjects.response();
|
||||||
|
Assertion assertion = TestOpenSamlObjects.assertion();
|
||||||
|
assertion.getSubject().getSubjectConfirmations()
|
||||||
|
.forEach((sc) -> sc.getSubjectConfirmationData().setAddress("10.10.10.10"));
|
||||||
|
TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(),
|
||||||
|
RELYING_PARTY_ENTITY_ID);
|
||||||
|
response.getAssertions().add(assertion);
|
||||||
|
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
|
||||||
|
this.provider.setAssertionDecrypter(mockAssertionAndPrincipalDecrypter());
|
||||||
|
assertThatExceptionOfType(Saml2AuthenticationException.class)
|
||||||
|
.isThrownBy(() -> this.provider.authenticate(token))
|
||||||
|
.satisfies(errorOf(Saml2ErrorCodes.INVALID_SIGNATURE));
|
||||||
|
assertThat(response.getAssertions().get(0).equals(TestOpenSamlObjects.assertion("1", "2", "3", "4")));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void setPrincipalDecrypterThenChangesAssertion() {
|
||||||
|
Response response = TestOpenSamlObjects.response();
|
||||||
|
Assertion assertion = TestOpenSamlObjects.assertion();
|
||||||
|
assertion.getSubject().getSubjectConfirmations()
|
||||||
|
.forEach((sc) -> sc.getSubjectConfirmationData().setAddress("10.10.10.10"));
|
||||||
|
TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(),
|
||||||
|
RELYING_PARTY_ENTITY_ID);
|
||||||
|
response.getAssertions().add(assertion);
|
||||||
|
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
|
||||||
|
this.provider.setPrincipalDecrypter(mockAssertionAndPrincipalDecrypter());
|
||||||
|
this.provider.authenticate(token);
|
||||||
|
assertThat(response.getAssertions().get(0).equals(TestOpenSamlObjects.assertion("1", "2", "3", "4")));
|
||||||
|
}
|
||||||
|
|
||||||
|
private Consumer<ResponseToken> mockAssertionAndPrincipalDecrypter() {
|
||||||
|
return (responseToken) -> {
|
||||||
|
responseToken.getResponse().getAssertions().clear();
|
||||||
|
responseToken.getResponse().getAssertions()
|
||||||
|
.add(TestOpenSamlObjects.signed(TestOpenSamlObjects.assertion("1", "2", "3", "4"),
|
||||||
|
TestSaml2X509Credentials.assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID));
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
private <T extends XMLObject> T build(QName qName) {
|
private <T extends XMLObject> T build(QName qName) {
|
||||||
return (T) XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(qName).buildObject(qName);
|
return (T) XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(qName).buildObject(qName);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue