diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java index 91c0ef43b0..5ca2bbcc5c 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java @@ -29,6 +29,7 @@ import java.util.Map; import java.util.Set; import java.util.HashSet; import java.util.Arrays; +import java.util.Optional; import java.util.function.Consumer; import javax.annotation.Nonnull; @@ -173,6 +174,8 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv private Converter responseAuthenticationConverter = createDefaultResponseAuthenticationConverter(); + private static final Set includeChildStatusCodes = new HashSet<>(Arrays.asList(REQUESTER, RESPONDER, VERSION_MISMATCH)); + /** * Creates an {@link OpenSaml4AuthenticationProvider} */ @@ -409,6 +412,26 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv }; } + private static String getStatusCode(Response response) { + if (response.getStatus() == null) { + return StatusCode.SUCCESS; + } + if (response.getStatus().getStatusCode() == null) { + return StatusCode.SUCCESS; + } + + StatusCode parentStatusCode = response.getStatus().getStatusCode(); + String parentStatusCodeValue = parentStatusCode.getValue(); + if (includeChildStatusCodes.contains(parentStatusCodeValue)) { + return Optional.ofNullable(parentStatusCode.getStatusCode()) + .map(StatusCode::getValue) + .map(childStatusCodeValue -> parentStatusCodeValue + childStatusCodeValue) + .orElse(parentStatusCodeValue); + } + + return parentStatusCodeValue; + } + private static Saml2ResponseValidatorResult validateInResponseTo(AbstractSaml2AuthenticationRequest storedRequest, String inResponseTo) { if (!StringUtils.hasText(inResponseTo)) { @@ -619,26 +642,6 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv }; } - private static String getStatusCode(Response response) { - if (response.getStatus() == null) { - return StatusCode.SUCCESS; - } - if (response.getStatus().getStatusCode() == null) { - return StatusCode.SUCCESS; - } - - Set statusCodes = new HashSet<>(Arrays.asList(REQUESTER, RESPONDER, VERSION_MISMATCH)); - StatusCode parentStatusCode = response.getStatus().getStatusCode(); - String parentStatusCodeValue = parentStatusCode.getValue(); - if (statusCodes.contains(parentStatusCodeValue)) { - StatusCode childStatusCode = parentStatusCode.getStatusCode(); - String childStatusCodeValue = childStatusCode.getValue(); - return parentStatusCodeValue + childStatusCodeValue; - } - - return parentStatusCodeValue; - } - private Converter createDefaultAssertionSignatureValidator() { return createAssertionValidator(Saml2ErrorCodes.INVALID_SIGNATURE, (assertionToken) -> { RelyingPartyRegistration registration = assertionToken.getToken().getRelyingPartyRegistration();