mirror of
https://github.com/spring-projects/spring-security.git
synced 2025-03-09 06:50:05 +00:00
Allow Defining Custom SAML Response Validator
Add a setter method into OpenSaml4AuthenticationProvider that allows defining a custom ResponseValidator Closes gh-9721
This commit is contained in:
parent
6474a9e76e
commit
03ded987af
@ -1271,8 +1271,29 @@ It's not required to call `OpenSaml4AuthenticationProvider` 's default authentic
|
||||
It returns a `Saml2AuthenticatedPrincipal` containing the attributes it extracted from `AttributeStatement` s as well as the single `ROLE_USER` authority.
|
||||
|
||||
[[servlet-saml2login-opensamlauthenticationprovider-additionalvalidation]]
|
||||
==== Performing Additional Validation
|
||||
==== Performing Additional Response Validation
|
||||
|
||||
`OpenSaml4AuthenticationProvider` validates the `Issuer` and `Destination` values right after decrypting the `Response`.
|
||||
You can customize the validation by extending the default validator concatenating with your own response validator, or you can replace it entirely with yours.
|
||||
|
||||
For example, you can throw a custom exception with any additional information available in the `Response` object, like so:
|
||||
[source,java]
|
||||
----
|
||||
OpenSaml4AuthenticationProvider provider = new OpenSaml4AuthenticationProvider();
|
||||
provider.setResponseValidator((responseToken) -> {
|
||||
Saml2ResponseValidatorResult result = OpenSamlAuthenticationProvider
|
||||
.createDefaultResponseValidator()
|
||||
.convert(responseToken)
|
||||
.concat(myCustomValidator.convert(responseToken));
|
||||
if (!result.getErrors().isEmpty()) {
|
||||
String inResponseTo = responseToken.getInResponseTo();
|
||||
throw new CustomSaml2AuthenticationException(result, inResponseTo);
|
||||
}
|
||||
return result;
|
||||
});
|
||||
----
|
||||
|
||||
==== Performing Additional Assertion Validation
|
||||
`OpenSaml4AuthenticationProvider` performs minimal validation on SAML 2.0 Assertions.
|
||||
After verifying the signature, it will:
|
||||
|
||||
|
@ -145,7 +145,7 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
|
||||
|
||||
private Consumer<ResponseToken> responseElementsDecrypter = createDefaultResponseElementsDecrypter();
|
||||
|
||||
private final Converter<ResponseToken, Saml2ResponseValidatorResult> responseValidator = createDefaultResponseValidator();
|
||||
private Converter<ResponseToken, Saml2ResponseValidatorResult> responseValidator = createDefaultResponseValidator();
|
||||
|
||||
private final Converter<AssertionToken, Saml2ResponseValidatorResult> assertionSignatureValidator = createDefaultAssertionSignatureValidator();
|
||||
|
||||
@ -213,6 +213,28 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
|
||||
this.responseElementsDecrypter = responseElementsDecrypter;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the {@link Converter} to use for validating the SAML 2.0 Response.
|
||||
*
|
||||
* You can still invoke the default validator by delegating to
|
||||
* {@link #createDefaultResponseValidator()}, like so:
|
||||
*
|
||||
* <pre>
|
||||
* OpenSaml4AuthenticationProvider provider = new OpenSaml4AuthenticationProvider();
|
||||
* provider.setResponseValidator(responseToken -> {
|
||||
* Saml2ResponseValidatorResult result = createDefaultResponseValidator()
|
||||
* .convert(responseToken)
|
||||
* return result.concat(myCustomValidator.convert(responseToken));
|
||||
* });
|
||||
* </pre>
|
||||
* @param responseValidator the {@link Converter} to use
|
||||
* @since 5.6
|
||||
*/
|
||||
public void setResponseValidator(Converter<ResponseToken, Saml2ResponseValidatorResult> responseValidator) {
|
||||
Assert.notNull(responseValidator, "responseValidator cannot be null");
|
||||
this.responseValidator = responseValidator;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the {@link Converter} to use for validating each {@link Assertion} in the SAML
|
||||
* 2.0 Response.
|
||||
@ -326,6 +348,44 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
|
||||
this.responseAuthenticationConverter = responseAuthenticationConverter;
|
||||
}
|
||||
|
||||
/**
|
||||
* Construct a default strategy for validating the SAML 2.0 Response
|
||||
* @return the default response validator strategy
|
||||
* @since 5.6
|
||||
*/
|
||||
public static Converter<ResponseToken, Saml2ResponseValidatorResult> createDefaultResponseValidator() {
|
||||
return (responseToken) -> {
|
||||
Response response = responseToken.getResponse();
|
||||
Saml2AuthenticationToken token = responseToken.getToken();
|
||||
Saml2ResponseValidatorResult result = Saml2ResponseValidatorResult.success();
|
||||
String statusCode = getStatusCode(response);
|
||||
if (!StatusCode.SUCCESS.equals(statusCode)) {
|
||||
String message = String.format("Invalid status [%s] for SAML response [%s]", statusCode,
|
||||
response.getID());
|
||||
result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, message));
|
||||
}
|
||||
String issuer = response.getIssuer().getValue();
|
||||
String destination = response.getDestination();
|
||||
String location = token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation();
|
||||
if (StringUtils.hasText(destination) && !destination.equals(location)) {
|
||||
String message = "Invalid destination [" + destination + "] for SAML response [" + response.getID()
|
||||
+ "]";
|
||||
result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_DESTINATION, message));
|
||||
}
|
||||
String assertingPartyEntityId = token.getRelyingPartyRegistration().getAssertingPartyDetails()
|
||||
.getEntityId();
|
||||
if (!StringUtils.hasText(issuer) || !issuer.equals(assertingPartyEntityId)) {
|
||||
String message = String.format("Invalid issuer [%s] for SAML response [%s]", issuer, response.getID());
|
||||
result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_ISSUER, message));
|
||||
}
|
||||
if (response.getAssertions().isEmpty()) {
|
||||
throw createAuthenticationException(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA,
|
||||
"No assertions found in response.", null);
|
||||
}
|
||||
return result;
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Construct a default strategy for validating each SAML 2.0 Assertion and associated
|
||||
* {@link Authentication} token
|
||||
@ -487,40 +547,7 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
|
||||
};
|
||||
}
|
||||
|
||||
private Converter<ResponseToken, Saml2ResponseValidatorResult> createDefaultResponseValidator() {
|
||||
return (responseToken) -> {
|
||||
Response response = responseToken.getResponse();
|
||||
Saml2AuthenticationToken token = responseToken.getToken();
|
||||
Saml2ResponseValidatorResult result = Saml2ResponseValidatorResult.success();
|
||||
String statusCode = getStatusCode(response);
|
||||
if (!StatusCode.SUCCESS.equals(statusCode)) {
|
||||
String message = String.format("Invalid status [%s] for SAML response [%s]", statusCode,
|
||||
response.getID());
|
||||
result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, message));
|
||||
}
|
||||
String issuer = response.getIssuer().getValue();
|
||||
String destination = response.getDestination();
|
||||
String location = token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation();
|
||||
if (StringUtils.hasText(destination) && !destination.equals(location)) {
|
||||
String message = "Invalid destination [" + destination + "] for SAML response [" + response.getID()
|
||||
+ "]";
|
||||
result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_DESTINATION, message));
|
||||
}
|
||||
String assertingPartyEntityId = token.getRelyingPartyRegistration().getAssertingPartyDetails()
|
||||
.getEntityId();
|
||||
if (!StringUtils.hasText(issuer) || !issuer.equals(assertingPartyEntityId)) {
|
||||
String message = String.format("Invalid issuer [%s] for SAML response [%s]", issuer, response.getID());
|
||||
result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_ISSUER, message));
|
||||
}
|
||||
if (response.getAssertions().isEmpty()) {
|
||||
throw createAuthenticationException(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA,
|
||||
"No assertions found in response.", null);
|
||||
}
|
||||
return result;
|
||||
};
|
||||
}
|
||||
|
||||
private String getStatusCode(Response response) {
|
||||
private static String getStatusCode(Response response) {
|
||||
if (response.getStatus() == null) {
|
||||
return StatusCode.SUCCESS;
|
||||
}
|
||||
|
@ -585,6 +585,34 @@ public class OpenSaml4AuthenticationProviderTests {
|
||||
assertThat(authentication.getName()).isEqualTo("test@saml.user");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void setResponseValidatorWhenNullThenIllegalArgument() {
|
||||
assertThatIllegalArgumentException().isThrownBy(() -> this.provider.setResponseValidator(null));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void authenticateWhenCustomResponseValidatorThenUses() {
|
||||
Converter<OpenSaml4AuthenticationProvider.ResponseToken, Saml2ResponseValidatorResult> validator = mock(
|
||||
Converter.class);
|
||||
OpenSaml4AuthenticationProvider provider = new OpenSaml4AuthenticationProvider();
|
||||
// @formatter:off
|
||||
provider.setResponseValidator((responseToken) -> OpenSaml4AuthenticationProvider.createDefaultResponseValidator()
|
||||
.convert(responseToken)
|
||||
.concat(validator.convert(responseToken))
|
||||
);
|
||||
// @formatter:on
|
||||
Response response = response();
|
||||
Assertion assertion = assertion();
|
||||
response.getAssertions().add(assertion);
|
||||
TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(),
|
||||
ASSERTING_PARTY_ENTITY_ID);
|
||||
Saml2AuthenticationToken token = token(response, verifying(registration()));
|
||||
given(validator.convert(any(OpenSaml4AuthenticationProvider.ResponseToken.class)))
|
||||
.willReturn(Saml2ResponseValidatorResult.success());
|
||||
provider.authenticate(token);
|
||||
verify(validator).convert(any(OpenSaml4AuthenticationProvider.ResponseToken.class));
|
||||
}
|
||||
|
||||
private <T extends XMLObject> T build(QName qName) {
|
||||
return (T) XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(qName).buildObject(qName);
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user