Update to return List of StatusCodes and add Saml2Error to result object and other formatting
This commit is contained in:
parent
76331a5653
commit
6e45e65cac
|
@ -20,16 +20,15 @@ import java.io.ByteArrayInputStream;
|
|||
import java.nio.charset.StandardCharsets;
|
||||
import java.time.Duration;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.HashSet;
|
||||
import java.util.Arrays;
|
||||
import java.util.Optional;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
import javax.annotation.Nonnull;
|
||||
|
@ -98,8 +97,6 @@ import org.springframework.util.LinkedMultiValueMap;
|
|||
import org.springframework.util.MultiValueMap;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
import static org.opensaml.saml.saml2.core.StatusCode.*;
|
||||
|
||||
/**
|
||||
* Implementation of {@link AuthenticationProvider} for SAML authentications when
|
||||
* receiving a {@code Response} object containing an {@code Assertion}. This
|
||||
|
@ -174,7 +171,8 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
|
|||
|
||||
private Converter<ResponseToken, ? extends AbstractAuthenticationToken> responseAuthenticationConverter = createDefaultResponseAuthenticationConverter();
|
||||
|
||||
private static final Set<String> includeChildStatusCodes = new HashSet<>(Arrays.asList(REQUESTER, RESPONDER, VERSION_MISMATCH));
|
||||
private static final Set<String> includeChildStatusCodes = new HashSet<>(
|
||||
Arrays.asList(StatusCode.REQUESTER, StatusCode.RESPONDER, StatusCode.VERSION_MISMATCH));
|
||||
|
||||
/**
|
||||
* Creates an {@link OpenSaml4AuthenticationProvider}
|
||||
|
@ -379,11 +377,13 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
|
|||
Response response = responseToken.getResponse();
|
||||
Saml2AuthenticationToken token = responseToken.getToken();
|
||||
Saml2ResponseValidatorResult result = Saml2ResponseValidatorResult.success();
|
||||
String statusCode = getStatusCode(response);
|
||||
if (!StatusCode.SUCCESS.equals(statusCode)) {
|
||||
String message = String.format("Invalid status [%s] for SAML response [%s]", statusCode,
|
||||
response.getID());
|
||||
result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, message));
|
||||
List<String> statusCodes = getStatusCodes(response);
|
||||
if (!isSuccess(statusCodes)) {
|
||||
for (String statusCode : statusCodes) {
|
||||
String message = String.format("Invalid status [%s] for SAML response [%s]", statusCode,
|
||||
response.getID());
|
||||
result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, message));
|
||||
}
|
||||
}
|
||||
|
||||
String inResponseTo = response.getInResponseTo();
|
||||
|
@ -412,24 +412,37 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
|
|||
};
|
||||
}
|
||||
|
||||
private static String getStatusCode(Response response) {
|
||||
private static List<String> getStatusCodes(Response response) {
|
||||
if (response.getStatus() == null) {
|
||||
return StatusCode.SUCCESS;
|
||||
return Arrays.asList(StatusCode.SUCCESS);
|
||||
}
|
||||
if (response.getStatus().getStatusCode() == null) {
|
||||
return StatusCode.SUCCESS;
|
||||
return Arrays.asList(StatusCode.SUCCESS);
|
||||
}
|
||||
|
||||
StatusCode parentStatusCode = response.getStatus().getStatusCode();
|
||||
String parentStatusCodeValue = parentStatusCode.getValue();
|
||||
if (includeChildStatusCodes.contains(parentStatusCodeValue)) {
|
||||
return Optional.ofNullable(parentStatusCode.getStatusCode())
|
||||
.map(StatusCode::getValue)
|
||||
.map(childStatusCodeValue -> parentStatusCodeValue + childStatusCodeValue)
|
||||
.orElse(parentStatusCodeValue);
|
||||
StatusCode statusCode = parentStatusCode.getStatusCode();
|
||||
if (statusCode != null) {
|
||||
String childStatusCodeValue = statusCode.getValue();
|
||||
if (childStatusCodeValue != null) {
|
||||
return Arrays.asList(parentStatusCodeValue, childStatusCodeValue);
|
||||
}
|
||||
}
|
||||
return Arrays.asList(parentStatusCodeValue);
|
||||
}
|
||||
|
||||
return parentStatusCodeValue;
|
||||
return Arrays.asList(parentStatusCodeValue);
|
||||
}
|
||||
|
||||
private static boolean isSuccess(List<String> statusCodes) {
|
||||
if (statusCodes.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
String statusCode = statusCodes.get(0);
|
||||
return StatusCode.SUCCESS.equals(statusCode);
|
||||
}
|
||||
|
||||
private static Saml2ResponseValidatorResult validateInResponseTo(AbstractSaml2AuthenticationRequest storedRequest,
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2023 the original author or authors.
|
||||
* Copyright 2002-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -86,8 +86,6 @@ import org.springframework.util.StringUtils;
|
|||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
|
||||
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.BDDMockito.given;
|
||||
import static org.mockito.Mockito.atLeastOnce;
|
||||
|
@ -736,7 +734,7 @@ public class OpenSaml4AuthenticationProviderTests {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void setsOnlyParentStatusCodeOnResultDescription() {
|
||||
public void authenticateWhenResponseStatusIsNotSuccessThenOnlyReturnParentStatusCodes() {
|
||||
ResponseToken mockResponseToken = mock(ResponseToken.class);
|
||||
Saml2AuthenticationToken mockSamlToken = mock(Saml2AuthenticationToken.class);
|
||||
given(mockResponseToken.getToken()).willReturn(mockSamlToken);
|
||||
|
@ -744,7 +742,8 @@ public class OpenSaml4AuthenticationProviderTests {
|
|||
RelyingPartyRegistration mockRelyingPartyRegistration = mock(RelyingPartyRegistration.class);
|
||||
given(mockSamlToken.getRelyingPartyRegistration()).willReturn(mockRelyingPartyRegistration);
|
||||
|
||||
RelyingPartyRegistration.AssertingPartyDetails mockAssertingPartyDetails = mock(RelyingPartyRegistration.AssertingPartyDetails.class);
|
||||
RelyingPartyRegistration.AssertingPartyDetails mockAssertingPartyDetails = mock(
|
||||
RelyingPartyRegistration.AssertingPartyDetails.class);
|
||||
given(mockRelyingPartyRegistration.getAssertingPartyDetails()).willReturn(mockAssertingPartyDetails);
|
||||
|
||||
Status parentStatus = new StatusBuilder().buildObject();
|
||||
|
@ -763,16 +762,21 @@ public class OpenSaml4AuthenticationProviderTests {
|
|||
|
||||
given(mockResponseToken.getResponse()).willReturn(mockResponse);
|
||||
|
||||
Converter<ResponseToken, Saml2ResponseValidatorResult> validator = OpenSaml4AuthenticationProvider.createDefaultResponseValidator();
|
||||
Converter<ResponseToken, Saml2ResponseValidatorResult> validator = OpenSaml4AuthenticationProvider
|
||||
.createDefaultResponseValidator();
|
||||
Saml2ResponseValidatorResult result = validator.convert(mockResponseToken);
|
||||
|
||||
String expectedErrorMessage = String.format("Invalid status [%s] for SAML response", parentStatusCode.getValue());
|
||||
assertTrue(result.getErrors().stream().anyMatch(error -> error.getDescription().contains(expectedErrorMessage)));
|
||||
assertFalse(result.getErrors().stream().anyMatch(error -> error.getDescription().contains(childStatusCode.getValue())));
|
||||
String expectedErrorMessage = String.format("Invalid status [%s] for SAML response",
|
||||
parentStatusCode.getValue());
|
||||
assertThat(
|
||||
result.getErrors().stream().anyMatch((error) -> error.getDescription().contains(expectedErrorMessage)));
|
||||
assertThat(result.getErrors()
|
||||
.stream()
|
||||
.noneMatch((error) -> error.getDescription().contains(childStatusCode.getValue())));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void setsParentAndChildStatusCodeOnResultDescription() {
|
||||
public void authenticateWhenResponseStatusIsNotSuccessThenReturnParentAndChildStatusCode() {
|
||||
ResponseToken mockResponseToken = mock(ResponseToken.class);
|
||||
Saml2AuthenticationToken mockSamlToken = mock(Saml2AuthenticationToken.class);
|
||||
given(mockResponseToken.getToken()).willReturn(mockSamlToken);
|
||||
|
@ -780,7 +784,8 @@ public class OpenSaml4AuthenticationProviderTests {
|
|||
RelyingPartyRegistration mockRelyingPartyRegistration = mock(RelyingPartyRegistration.class);
|
||||
given(mockSamlToken.getRelyingPartyRegistration()).willReturn(mockRelyingPartyRegistration);
|
||||
|
||||
RelyingPartyRegistration.AssertingPartyDetails mockAssertingPartyDetails = mock(RelyingPartyRegistration.AssertingPartyDetails.class);
|
||||
RelyingPartyRegistration.AssertingPartyDetails mockAssertingPartyDetails = mock(
|
||||
RelyingPartyRegistration.AssertingPartyDetails.class);
|
||||
given(mockRelyingPartyRegistration.getAssertingPartyDetails()).willReturn(mockAssertingPartyDetails);
|
||||
|
||||
Status parentStatus = new StatusBuilder().buildObject();
|
||||
|
@ -799,11 +804,20 @@ public class OpenSaml4AuthenticationProviderTests {
|
|||
|
||||
given(mockResponseToken.getResponse()).willReturn(mockResponse);
|
||||
|
||||
Converter<ResponseToken, Saml2ResponseValidatorResult> validator = OpenSaml4AuthenticationProvider.createDefaultResponseValidator();
|
||||
Converter<ResponseToken, Saml2ResponseValidatorResult> validator = OpenSaml4AuthenticationProvider
|
||||
.createDefaultResponseValidator();
|
||||
Saml2ResponseValidatorResult result = validator.convert(mockResponseToken);
|
||||
|
||||
String expectedErrorMessage = String.format("Invalid status [%s] for SAML response", parentStatusCode.getValue() + childStatusCode.getValue());
|
||||
assertTrue(result.getErrors().stream().anyMatch(error -> error.getDescription().contains(expectedErrorMessage)));
|
||||
String expectedParentErrorMessage = String.format("Invalid status [%s] for SAML response",
|
||||
parentStatusCode.getValue());
|
||||
String expectedChildErrorMessage = String.format("Invalid status [%s] for SAML response",
|
||||
childStatusCode.getValue());
|
||||
assertThat(result.getErrors()
|
||||
.stream()
|
||||
.anyMatch((error) -> error.getDescription().contains(expectedParentErrorMessage)));
|
||||
assertThat(result.getErrors()
|
||||
.stream()
|
||||
.anyMatch((error) -> error.getDescription().contains(expectedChildErrorMessage)));
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
Loading…
Reference in New Issue