Update to return List of StatusCodes and add Saml2Error to result object and other formatting

This commit is contained in:
YoungKi Hong 2024-03-21 23:56:03 +09:00 committed by Josh Cummings
parent 76331a5653
commit 6e45e65cac
2 changed files with 60 additions and 33 deletions

View File

@ -20,16 +20,15 @@ import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.time.Duration; import java.time.Duration;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.HashSet;
import java.util.Arrays;
import java.util.Optional;
import java.util.function.Consumer; import java.util.function.Consumer;
import javax.annotation.Nonnull; import javax.annotation.Nonnull;
@ -98,8 +97,6 @@ import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap; import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import static org.opensaml.saml.saml2.core.StatusCode.*;
/** /**
* Implementation of {@link AuthenticationProvider} for SAML authentications when * Implementation of {@link AuthenticationProvider} for SAML authentications when
* receiving a {@code Response} object containing an {@code Assertion}. This * receiving a {@code Response} object containing an {@code Assertion}. This
@ -174,7 +171,8 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
private Converter<ResponseToken, ? extends AbstractAuthenticationToken> responseAuthenticationConverter = createDefaultResponseAuthenticationConverter(); private Converter<ResponseToken, ? extends AbstractAuthenticationToken> responseAuthenticationConverter = createDefaultResponseAuthenticationConverter();
private static final Set<String> includeChildStatusCodes = new HashSet<>(Arrays.asList(REQUESTER, RESPONDER, VERSION_MISMATCH)); private static final Set<String> includeChildStatusCodes = new HashSet<>(
Arrays.asList(StatusCode.REQUESTER, StatusCode.RESPONDER, StatusCode.VERSION_MISMATCH));
/** /**
* Creates an {@link OpenSaml4AuthenticationProvider} * Creates an {@link OpenSaml4AuthenticationProvider}
@ -379,11 +377,13 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
Response response = responseToken.getResponse(); Response response = responseToken.getResponse();
Saml2AuthenticationToken token = responseToken.getToken(); Saml2AuthenticationToken token = responseToken.getToken();
Saml2ResponseValidatorResult result = Saml2ResponseValidatorResult.success(); Saml2ResponseValidatorResult result = Saml2ResponseValidatorResult.success();
String statusCode = getStatusCode(response); List<String> statusCodes = getStatusCodes(response);
if (!StatusCode.SUCCESS.equals(statusCode)) { if (!isSuccess(statusCodes)) {
String message = String.format("Invalid status [%s] for SAML response [%s]", statusCode, for (String statusCode : statusCodes) {
response.getID()); String message = String.format("Invalid status [%s] for SAML response [%s]", statusCode,
result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, message)); response.getID());
result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, message));
}
} }
String inResponseTo = response.getInResponseTo(); String inResponseTo = response.getInResponseTo();
@ -412,24 +412,37 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
}; };
} }
private static String getStatusCode(Response response) { private static List<String> getStatusCodes(Response response) {
if (response.getStatus() == null) { if (response.getStatus() == null) {
return StatusCode.SUCCESS; return Arrays.asList(StatusCode.SUCCESS);
} }
if (response.getStatus().getStatusCode() == null) { if (response.getStatus().getStatusCode() == null) {
return StatusCode.SUCCESS; return Arrays.asList(StatusCode.SUCCESS);
} }
StatusCode parentStatusCode = response.getStatus().getStatusCode(); StatusCode parentStatusCode = response.getStatus().getStatusCode();
String parentStatusCodeValue = parentStatusCode.getValue(); String parentStatusCodeValue = parentStatusCode.getValue();
if (includeChildStatusCodes.contains(parentStatusCodeValue)) { if (includeChildStatusCodes.contains(parentStatusCodeValue)) {
return Optional.ofNullable(parentStatusCode.getStatusCode()) StatusCode statusCode = parentStatusCode.getStatusCode();
.map(StatusCode::getValue) if (statusCode != null) {
.map(childStatusCodeValue -> parentStatusCodeValue + childStatusCodeValue) String childStatusCodeValue = statusCode.getValue();
.orElse(parentStatusCodeValue); if (childStatusCodeValue != null) {
return Arrays.asList(parentStatusCodeValue, childStatusCodeValue);
}
}
return Arrays.asList(parentStatusCodeValue);
} }
return parentStatusCodeValue; return Arrays.asList(parentStatusCodeValue);
}
private static boolean isSuccess(List<String> statusCodes) {
if (statusCodes.size() != 1) {
return false;
}
String statusCode = statusCodes.get(0);
return StatusCode.SUCCESS.equals(statusCode);
} }
private static Saml2ResponseValidatorResult validateInResponseTo(AbstractSaml2AuthenticationRequest storedRequest, private static Saml2ResponseValidatorResult validateInResponseTo(AbstractSaml2AuthenticationRequest storedRequest,

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2023 the original author or authors. * Copyright 2002-2024 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -86,8 +86,6 @@ import org.springframework.util.StringUtils;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.atLeastOnce;
@ -736,7 +734,7 @@ public class OpenSaml4AuthenticationProviderTests {
} }
@Test @Test
public void setsOnlyParentStatusCodeOnResultDescription() { public void authenticateWhenResponseStatusIsNotSuccessThenOnlyReturnParentStatusCodes() {
ResponseToken mockResponseToken = mock(ResponseToken.class); ResponseToken mockResponseToken = mock(ResponseToken.class);
Saml2AuthenticationToken mockSamlToken = mock(Saml2AuthenticationToken.class); Saml2AuthenticationToken mockSamlToken = mock(Saml2AuthenticationToken.class);
given(mockResponseToken.getToken()).willReturn(mockSamlToken); given(mockResponseToken.getToken()).willReturn(mockSamlToken);
@ -744,7 +742,8 @@ public class OpenSaml4AuthenticationProviderTests {
RelyingPartyRegistration mockRelyingPartyRegistration = mock(RelyingPartyRegistration.class); RelyingPartyRegistration mockRelyingPartyRegistration = mock(RelyingPartyRegistration.class);
given(mockSamlToken.getRelyingPartyRegistration()).willReturn(mockRelyingPartyRegistration); given(mockSamlToken.getRelyingPartyRegistration()).willReturn(mockRelyingPartyRegistration);
RelyingPartyRegistration.AssertingPartyDetails mockAssertingPartyDetails = mock(RelyingPartyRegistration.AssertingPartyDetails.class); RelyingPartyRegistration.AssertingPartyDetails mockAssertingPartyDetails = mock(
RelyingPartyRegistration.AssertingPartyDetails.class);
given(mockRelyingPartyRegistration.getAssertingPartyDetails()).willReturn(mockAssertingPartyDetails); given(mockRelyingPartyRegistration.getAssertingPartyDetails()).willReturn(mockAssertingPartyDetails);
Status parentStatus = new StatusBuilder().buildObject(); Status parentStatus = new StatusBuilder().buildObject();
@ -763,16 +762,21 @@ public class OpenSaml4AuthenticationProviderTests {
given(mockResponseToken.getResponse()).willReturn(mockResponse); given(mockResponseToken.getResponse()).willReturn(mockResponse);
Converter<ResponseToken, Saml2ResponseValidatorResult> validator = OpenSaml4AuthenticationProvider.createDefaultResponseValidator(); Converter<ResponseToken, Saml2ResponseValidatorResult> validator = OpenSaml4AuthenticationProvider
.createDefaultResponseValidator();
Saml2ResponseValidatorResult result = validator.convert(mockResponseToken); Saml2ResponseValidatorResult result = validator.convert(mockResponseToken);
String expectedErrorMessage = String.format("Invalid status [%s] for SAML response", parentStatusCode.getValue()); String expectedErrorMessage = String.format("Invalid status [%s] for SAML response",
assertTrue(result.getErrors().stream().anyMatch(error -> error.getDescription().contains(expectedErrorMessage))); parentStatusCode.getValue());
assertFalse(result.getErrors().stream().anyMatch(error -> error.getDescription().contains(childStatusCode.getValue()))); assertThat(
result.getErrors().stream().anyMatch((error) -> error.getDescription().contains(expectedErrorMessage)));
assertThat(result.getErrors()
.stream()
.noneMatch((error) -> error.getDescription().contains(childStatusCode.getValue())));
} }
@Test @Test
public void setsParentAndChildStatusCodeOnResultDescription() { public void authenticateWhenResponseStatusIsNotSuccessThenReturnParentAndChildStatusCode() {
ResponseToken mockResponseToken = mock(ResponseToken.class); ResponseToken mockResponseToken = mock(ResponseToken.class);
Saml2AuthenticationToken mockSamlToken = mock(Saml2AuthenticationToken.class); Saml2AuthenticationToken mockSamlToken = mock(Saml2AuthenticationToken.class);
given(mockResponseToken.getToken()).willReturn(mockSamlToken); given(mockResponseToken.getToken()).willReturn(mockSamlToken);
@ -780,7 +784,8 @@ public class OpenSaml4AuthenticationProviderTests {
RelyingPartyRegistration mockRelyingPartyRegistration = mock(RelyingPartyRegistration.class); RelyingPartyRegistration mockRelyingPartyRegistration = mock(RelyingPartyRegistration.class);
given(mockSamlToken.getRelyingPartyRegistration()).willReturn(mockRelyingPartyRegistration); given(mockSamlToken.getRelyingPartyRegistration()).willReturn(mockRelyingPartyRegistration);
RelyingPartyRegistration.AssertingPartyDetails mockAssertingPartyDetails = mock(RelyingPartyRegistration.AssertingPartyDetails.class); RelyingPartyRegistration.AssertingPartyDetails mockAssertingPartyDetails = mock(
RelyingPartyRegistration.AssertingPartyDetails.class);
given(mockRelyingPartyRegistration.getAssertingPartyDetails()).willReturn(mockAssertingPartyDetails); given(mockRelyingPartyRegistration.getAssertingPartyDetails()).willReturn(mockAssertingPartyDetails);
Status parentStatus = new StatusBuilder().buildObject(); Status parentStatus = new StatusBuilder().buildObject();
@ -799,11 +804,20 @@ public class OpenSaml4AuthenticationProviderTests {
given(mockResponseToken.getResponse()).willReturn(mockResponse); given(mockResponseToken.getResponse()).willReturn(mockResponse);
Converter<ResponseToken, Saml2ResponseValidatorResult> validator = OpenSaml4AuthenticationProvider.createDefaultResponseValidator(); Converter<ResponseToken, Saml2ResponseValidatorResult> validator = OpenSaml4AuthenticationProvider
.createDefaultResponseValidator();
Saml2ResponseValidatorResult result = validator.convert(mockResponseToken); Saml2ResponseValidatorResult result = validator.convert(mockResponseToken);
String expectedErrorMessage = String.format("Invalid status [%s] for SAML response", parentStatusCode.getValue() + childStatusCode.getValue()); String expectedParentErrorMessage = String.format("Invalid status [%s] for SAML response",
assertTrue(result.getErrors().stream().anyMatch(error -> error.getDescription().contains(expectedErrorMessage))); parentStatusCode.getValue());
String expectedChildErrorMessage = String.format("Invalid status [%s] for SAML response",
childStatusCode.getValue());
assertThat(result.getErrors()
.stream()
.anyMatch((error) -> error.getDescription().contains(expectedParentErrorMessage)));
assertThat(result.getErrors()
.stream()
.anyMatch((error) -> error.getDescription().contains(expectedChildErrorMessage)));
} }
@Test @Test