Use Optional in case child status code is null

This commit is contained in:
youngkih 2024-03-16 21:58:26 +09:00 committed by Josh Cummings
parent 01e2971085
commit 994e064412
1 changed files with 23 additions and 20 deletions

View File

@ -29,6 +29,7 @@ import java.util.Map;
import java.util.Set;
import java.util.HashSet;
import java.util.Arrays;
import java.util.Optional;
import java.util.function.Consumer;
import javax.annotation.Nonnull;
@ -173,6 +174,8 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
private Converter<ResponseToken, ? extends AbstractAuthenticationToken> responseAuthenticationConverter = createDefaultResponseAuthenticationConverter();
private static final Set<String> includeChildStatusCodes = new HashSet<>(Arrays.asList(REQUESTER, RESPONDER, VERSION_MISMATCH));
/**
* Creates an {@link OpenSaml4AuthenticationProvider}
*/
@ -409,6 +412,26 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
};
}
private static String getStatusCode(Response response) {
if (response.getStatus() == null) {
return StatusCode.SUCCESS;
}
if (response.getStatus().getStatusCode() == null) {
return StatusCode.SUCCESS;
}
StatusCode parentStatusCode = response.getStatus().getStatusCode();
String parentStatusCodeValue = parentStatusCode.getValue();
if (includeChildStatusCodes.contains(parentStatusCodeValue)) {
return Optional.ofNullable(parentStatusCode.getStatusCode())
.map(StatusCode::getValue)
.map(childStatusCodeValue -> parentStatusCodeValue + childStatusCodeValue)
.orElse(parentStatusCodeValue);
}
return parentStatusCodeValue;
}
private static Saml2ResponseValidatorResult validateInResponseTo(AbstractSaml2AuthenticationRequest storedRequest,
String inResponseTo) {
if (!StringUtils.hasText(inResponseTo)) {
@ -619,26 +642,6 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
};
}
private static String getStatusCode(Response response) {
if (response.getStatus() == null) {
return StatusCode.SUCCESS;
}
if (response.getStatus().getStatusCode() == null) {
return StatusCode.SUCCESS;
}
Set<String> statusCodes = new HashSet<>(Arrays.asList(REQUESTER, RESPONDER, VERSION_MISMATCH));
StatusCode parentStatusCode = response.getStatus().getStatusCode();
String parentStatusCodeValue = parentStatusCode.getValue();
if (statusCodes.contains(parentStatusCodeValue)) {
StatusCode childStatusCode = parentStatusCode.getStatusCode();
String childStatusCodeValue = childStatusCode.getValue();
return parentStatusCodeValue + childStatusCodeValue;
}
return parentStatusCodeValue;
}
private Converter<AssertionToken, Saml2ResponseValidatorResult> createDefaultAssertionSignatureValidator() {
return createAssertionValidator(Saml2ErrorCodes.INVALID_SIGNATURE, (assertionToken) -> {
RelyingPartyRegistration registration = assertionToken.getToken().getRelyingPartyRegistration();