Update to return List of StatusCodes and add Saml2Error to result object and other formatting

This commit is contained in:
YoungKi Hong 2024-03-21 23:56:03 +09:00 committed by Josh Cummings
parent 76331a5653
commit 6e45e65cac
2 changed files with 60 additions and 33 deletions

View File

@ -20,16 +20,15 @@ import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.HashSet;
import java.util.Arrays;
import java.util.Optional;
import java.util.function.Consumer;
import javax.annotation.Nonnull;
@ -98,8 +97,6 @@ import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import static org.opensaml.saml.saml2.core.StatusCode.*;
/**
* Implementation of {@link AuthenticationProvider} for SAML authentications when
* receiving a {@code Response} object containing an {@code Assertion}. This
@ -174,7 +171,8 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
private Converter<ResponseToken, ? extends AbstractAuthenticationToken> responseAuthenticationConverter = createDefaultResponseAuthenticationConverter();
private static final Set<String> includeChildStatusCodes = new HashSet<>(Arrays.asList(REQUESTER, RESPONDER, VERSION_MISMATCH));
private static final Set<String> includeChildStatusCodes = new HashSet<>(
Arrays.asList(StatusCode.REQUESTER, StatusCode.RESPONDER, StatusCode.VERSION_MISMATCH));
/**
* Creates an {@link OpenSaml4AuthenticationProvider}
@ -379,11 +377,13 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
Response response = responseToken.getResponse();
Saml2AuthenticationToken token = responseToken.getToken();
Saml2ResponseValidatorResult result = Saml2ResponseValidatorResult.success();
String statusCode = getStatusCode(response);
if (!StatusCode.SUCCESS.equals(statusCode)) {
String message = String.format("Invalid status [%s] for SAML response [%s]", statusCode,
response.getID());
result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, message));
List<String> statusCodes = getStatusCodes(response);
if (!isSuccess(statusCodes)) {
for (String statusCode : statusCodes) {
String message = String.format("Invalid status [%s] for SAML response [%s]", statusCode,
response.getID());
result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, message));
}
}
String inResponseTo = response.getInResponseTo();
@ -412,24 +412,37 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
};
}
private static String getStatusCode(Response response) {
private static List<String> getStatusCodes(Response response) {
if (response.getStatus() == null) {
return StatusCode.SUCCESS;
return Arrays.asList(StatusCode.SUCCESS);
}
if (response.getStatus().getStatusCode() == null) {
return StatusCode.SUCCESS;
return Arrays.asList(StatusCode.SUCCESS);
}
StatusCode parentStatusCode = response.getStatus().getStatusCode();
String parentStatusCodeValue = parentStatusCode.getValue();
if (includeChildStatusCodes.contains(parentStatusCodeValue)) {
return Optional.ofNullable(parentStatusCode.getStatusCode())
.map(StatusCode::getValue)
.map(childStatusCodeValue -> parentStatusCodeValue + childStatusCodeValue)
.orElse(parentStatusCodeValue);
StatusCode statusCode = parentStatusCode.getStatusCode();
if (statusCode != null) {
String childStatusCodeValue = statusCode.getValue();
if (childStatusCodeValue != null) {
return Arrays.asList(parentStatusCodeValue, childStatusCodeValue);
}
}
return Arrays.asList(parentStatusCodeValue);
}
return parentStatusCodeValue;
return Arrays.asList(parentStatusCodeValue);
}
private static boolean isSuccess(List<String> statusCodes) {
if (statusCodes.size() != 1) {
return false;
}
String statusCode = statusCodes.get(0);
return StatusCode.SUCCESS.equals(statusCode);
}
private static Saml2ResponseValidatorResult validateInResponseTo(AbstractSaml2AuthenticationRequest storedRequest,

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -86,8 +86,6 @@ import org.springframework.util.StringUtils;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.atLeastOnce;
@ -736,7 +734,7 @@ public class OpenSaml4AuthenticationProviderTests {
}
@Test
public void setsOnlyParentStatusCodeOnResultDescription() {
public void authenticateWhenResponseStatusIsNotSuccessThenOnlyReturnParentStatusCodes() {
ResponseToken mockResponseToken = mock(ResponseToken.class);
Saml2AuthenticationToken mockSamlToken = mock(Saml2AuthenticationToken.class);
given(mockResponseToken.getToken()).willReturn(mockSamlToken);
@ -744,7 +742,8 @@ public class OpenSaml4AuthenticationProviderTests {
RelyingPartyRegistration mockRelyingPartyRegistration = mock(RelyingPartyRegistration.class);
given(mockSamlToken.getRelyingPartyRegistration()).willReturn(mockRelyingPartyRegistration);
RelyingPartyRegistration.AssertingPartyDetails mockAssertingPartyDetails = mock(RelyingPartyRegistration.AssertingPartyDetails.class);
RelyingPartyRegistration.AssertingPartyDetails mockAssertingPartyDetails = mock(
RelyingPartyRegistration.AssertingPartyDetails.class);
given(mockRelyingPartyRegistration.getAssertingPartyDetails()).willReturn(mockAssertingPartyDetails);
Status parentStatus = new StatusBuilder().buildObject();
@ -763,16 +762,21 @@ public class OpenSaml4AuthenticationProviderTests {
given(mockResponseToken.getResponse()).willReturn(mockResponse);
Converter<ResponseToken, Saml2ResponseValidatorResult> validator = OpenSaml4AuthenticationProvider.createDefaultResponseValidator();
Converter<ResponseToken, Saml2ResponseValidatorResult> validator = OpenSaml4AuthenticationProvider
.createDefaultResponseValidator();
Saml2ResponseValidatorResult result = validator.convert(mockResponseToken);
String expectedErrorMessage = String.format("Invalid status [%s] for SAML response", parentStatusCode.getValue());
assertTrue(result.getErrors().stream().anyMatch(error -> error.getDescription().contains(expectedErrorMessage)));
assertFalse(result.getErrors().stream().anyMatch(error -> error.getDescription().contains(childStatusCode.getValue())));
String expectedErrorMessage = String.format("Invalid status [%s] for SAML response",
parentStatusCode.getValue());
assertThat(
result.getErrors().stream().anyMatch((error) -> error.getDescription().contains(expectedErrorMessage)));
assertThat(result.getErrors()
.stream()
.noneMatch((error) -> error.getDescription().contains(childStatusCode.getValue())));
}
@Test
public void setsParentAndChildStatusCodeOnResultDescription() {
public void authenticateWhenResponseStatusIsNotSuccessThenReturnParentAndChildStatusCode() {
ResponseToken mockResponseToken = mock(ResponseToken.class);
Saml2AuthenticationToken mockSamlToken = mock(Saml2AuthenticationToken.class);
given(mockResponseToken.getToken()).willReturn(mockSamlToken);
@ -780,7 +784,8 @@ public class OpenSaml4AuthenticationProviderTests {
RelyingPartyRegistration mockRelyingPartyRegistration = mock(RelyingPartyRegistration.class);
given(mockSamlToken.getRelyingPartyRegistration()).willReturn(mockRelyingPartyRegistration);
RelyingPartyRegistration.AssertingPartyDetails mockAssertingPartyDetails = mock(RelyingPartyRegistration.AssertingPartyDetails.class);
RelyingPartyRegistration.AssertingPartyDetails mockAssertingPartyDetails = mock(
RelyingPartyRegistration.AssertingPartyDetails.class);
given(mockRelyingPartyRegistration.getAssertingPartyDetails()).willReturn(mockAssertingPartyDetails);
Status parentStatus = new StatusBuilder().buildObject();
@ -799,11 +804,20 @@ public class OpenSaml4AuthenticationProviderTests {
given(mockResponseToken.getResponse()).willReturn(mockResponse);
Converter<ResponseToken, Saml2ResponseValidatorResult> validator = OpenSaml4AuthenticationProvider.createDefaultResponseValidator();
Converter<ResponseToken, Saml2ResponseValidatorResult> validator = OpenSaml4AuthenticationProvider
.createDefaultResponseValidator();
Saml2ResponseValidatorResult result = validator.convert(mockResponseToken);
String expectedErrorMessage = String.format("Invalid status [%s] for SAML response", parentStatusCode.getValue() + childStatusCode.getValue());
assertTrue(result.getErrors().stream().anyMatch(error -> error.getDescription().contains(expectedErrorMessage)));
String expectedParentErrorMessage = String.format("Invalid status [%s] for SAML response",
parentStatusCode.getValue());
String expectedChildErrorMessage = String.format("Invalid status [%s] for SAML response",
childStatusCode.getValue());
assertThat(result.getErrors()
.stream()
.anyMatch((error) -> error.getDescription().contains(expectedParentErrorMessage)));
assertThat(result.getErrors()
.stream()
.anyMatch((error) -> error.getDescription().contains(expectedChildErrorMessage)));
}
@Test