mirror of
https://github.com/spring-projects/spring-security.git
synced 2025-06-28 06:42:49 +00:00
Update to return List of StatusCodes and add Saml2Error to result object and other formatting
This commit is contained in:
parent
76331a5653
commit
6e45e65cac
@ -20,16 +20,15 @@ import java.io.ByteArrayInputStream;
|
|||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
import java.util.Arrays;
|
||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
|
import java.util.HashSet;
|
||||||
import java.util.LinkedHashMap;
|
import java.util.LinkedHashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.HashSet;
|
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.Optional;
|
|
||||||
import java.util.function.Consumer;
|
import java.util.function.Consumer;
|
||||||
|
|
||||||
import javax.annotation.Nonnull;
|
import javax.annotation.Nonnull;
|
||||||
@ -98,8 +97,6 @@ import org.springframework.util.LinkedMultiValueMap;
|
|||||||
import org.springframework.util.MultiValueMap;
|
import org.springframework.util.MultiValueMap;
|
||||||
import org.springframework.util.StringUtils;
|
import org.springframework.util.StringUtils;
|
||||||
|
|
||||||
import static org.opensaml.saml.saml2.core.StatusCode.*;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Implementation of {@link AuthenticationProvider} for SAML authentications when
|
* Implementation of {@link AuthenticationProvider} for SAML authentications when
|
||||||
* receiving a {@code Response} object containing an {@code Assertion}. This
|
* receiving a {@code Response} object containing an {@code Assertion}. This
|
||||||
@ -174,7 +171,8 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
|
|||||||
|
|
||||||
private Converter<ResponseToken, ? extends AbstractAuthenticationToken> responseAuthenticationConverter = createDefaultResponseAuthenticationConverter();
|
private Converter<ResponseToken, ? extends AbstractAuthenticationToken> responseAuthenticationConverter = createDefaultResponseAuthenticationConverter();
|
||||||
|
|
||||||
private static final Set<String> includeChildStatusCodes = new HashSet<>(Arrays.asList(REQUESTER, RESPONDER, VERSION_MISMATCH));
|
private static final Set<String> includeChildStatusCodes = new HashSet<>(
|
||||||
|
Arrays.asList(StatusCode.REQUESTER, StatusCode.RESPONDER, StatusCode.VERSION_MISMATCH));
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates an {@link OpenSaml4AuthenticationProvider}
|
* Creates an {@link OpenSaml4AuthenticationProvider}
|
||||||
@ -379,12 +377,14 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
|
|||||||
Response response = responseToken.getResponse();
|
Response response = responseToken.getResponse();
|
||||||
Saml2AuthenticationToken token = responseToken.getToken();
|
Saml2AuthenticationToken token = responseToken.getToken();
|
||||||
Saml2ResponseValidatorResult result = Saml2ResponseValidatorResult.success();
|
Saml2ResponseValidatorResult result = Saml2ResponseValidatorResult.success();
|
||||||
String statusCode = getStatusCode(response);
|
List<String> statusCodes = getStatusCodes(response);
|
||||||
if (!StatusCode.SUCCESS.equals(statusCode)) {
|
if (!isSuccess(statusCodes)) {
|
||||||
|
for (String statusCode : statusCodes) {
|
||||||
String message = String.format("Invalid status [%s] for SAML response [%s]", statusCode,
|
String message = String.format("Invalid status [%s] for SAML response [%s]", statusCode,
|
||||||
response.getID());
|
response.getID());
|
||||||
result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, message));
|
result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, message));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
String inResponseTo = response.getInResponseTo();
|
String inResponseTo = response.getInResponseTo();
|
||||||
result = result.concat(validateInResponseTo(token.getAuthenticationRequest(), inResponseTo));
|
result = result.concat(validateInResponseTo(token.getAuthenticationRequest(), inResponseTo));
|
||||||
@ -412,24 +412,37 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
private static String getStatusCode(Response response) {
|
private static List<String> getStatusCodes(Response response) {
|
||||||
if (response.getStatus() == null) {
|
if (response.getStatus() == null) {
|
||||||
return StatusCode.SUCCESS;
|
return Arrays.asList(StatusCode.SUCCESS);
|
||||||
}
|
}
|
||||||
if (response.getStatus().getStatusCode() == null) {
|
if (response.getStatus().getStatusCode() == null) {
|
||||||
return StatusCode.SUCCESS;
|
return Arrays.asList(StatusCode.SUCCESS);
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusCode parentStatusCode = response.getStatus().getStatusCode();
|
StatusCode parentStatusCode = response.getStatus().getStatusCode();
|
||||||
String parentStatusCodeValue = parentStatusCode.getValue();
|
String parentStatusCodeValue = parentStatusCode.getValue();
|
||||||
if (includeChildStatusCodes.contains(parentStatusCodeValue)) {
|
if (includeChildStatusCodes.contains(parentStatusCodeValue)) {
|
||||||
return Optional.ofNullable(parentStatusCode.getStatusCode())
|
StatusCode statusCode = parentStatusCode.getStatusCode();
|
||||||
.map(StatusCode::getValue)
|
if (statusCode != null) {
|
||||||
.map(childStatusCodeValue -> parentStatusCodeValue + childStatusCodeValue)
|
String childStatusCodeValue = statusCode.getValue();
|
||||||
.orElse(parentStatusCodeValue);
|
if (childStatusCodeValue != null) {
|
||||||
|
return Arrays.asList(parentStatusCodeValue, childStatusCodeValue);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Arrays.asList(parentStatusCodeValue);
|
||||||
}
|
}
|
||||||
|
|
||||||
return parentStatusCodeValue;
|
return Arrays.asList(parentStatusCodeValue);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static boolean isSuccess(List<String> statusCodes) {
|
||||||
|
if (statusCodes.size() != 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
String statusCode = statusCodes.get(0);
|
||||||
|
return StatusCode.SUCCESS.equals(statusCode);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static Saml2ResponseValidatorResult validateInResponseTo(AbstractSaml2AuthenticationRequest storedRequest,
|
private static Saml2ResponseValidatorResult validateInResponseTo(AbstractSaml2AuthenticationRequest storedRequest,
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
* Copyright 2002-2023 the original author or authors.
|
* Copyright 2002-2024 the original author or authors.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
@ -86,8 +86,6 @@ import org.springframework.util.StringUtils;
|
|||||||
import static org.assertj.core.api.Assertions.assertThat;
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
|
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
|
||||||
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
|
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
|
||||||
import static org.junit.Assert.assertFalse;
|
|
||||||
import static org.junit.Assert.assertTrue;
|
|
||||||
import static org.mockito.ArgumentMatchers.any;
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
import static org.mockito.BDDMockito.given;
|
import static org.mockito.BDDMockito.given;
|
||||||
import static org.mockito.Mockito.atLeastOnce;
|
import static org.mockito.Mockito.atLeastOnce;
|
||||||
@ -736,7 +734,7 @@ public class OpenSaml4AuthenticationProviderTests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void setsOnlyParentStatusCodeOnResultDescription() {
|
public void authenticateWhenResponseStatusIsNotSuccessThenOnlyReturnParentStatusCodes() {
|
||||||
ResponseToken mockResponseToken = mock(ResponseToken.class);
|
ResponseToken mockResponseToken = mock(ResponseToken.class);
|
||||||
Saml2AuthenticationToken mockSamlToken = mock(Saml2AuthenticationToken.class);
|
Saml2AuthenticationToken mockSamlToken = mock(Saml2AuthenticationToken.class);
|
||||||
given(mockResponseToken.getToken()).willReturn(mockSamlToken);
|
given(mockResponseToken.getToken()).willReturn(mockSamlToken);
|
||||||
@ -744,7 +742,8 @@ public class OpenSaml4AuthenticationProviderTests {
|
|||||||
RelyingPartyRegistration mockRelyingPartyRegistration = mock(RelyingPartyRegistration.class);
|
RelyingPartyRegistration mockRelyingPartyRegistration = mock(RelyingPartyRegistration.class);
|
||||||
given(mockSamlToken.getRelyingPartyRegistration()).willReturn(mockRelyingPartyRegistration);
|
given(mockSamlToken.getRelyingPartyRegistration()).willReturn(mockRelyingPartyRegistration);
|
||||||
|
|
||||||
RelyingPartyRegistration.AssertingPartyDetails mockAssertingPartyDetails = mock(RelyingPartyRegistration.AssertingPartyDetails.class);
|
RelyingPartyRegistration.AssertingPartyDetails mockAssertingPartyDetails = mock(
|
||||||
|
RelyingPartyRegistration.AssertingPartyDetails.class);
|
||||||
given(mockRelyingPartyRegistration.getAssertingPartyDetails()).willReturn(mockAssertingPartyDetails);
|
given(mockRelyingPartyRegistration.getAssertingPartyDetails()).willReturn(mockAssertingPartyDetails);
|
||||||
|
|
||||||
Status parentStatus = new StatusBuilder().buildObject();
|
Status parentStatus = new StatusBuilder().buildObject();
|
||||||
@ -763,16 +762,21 @@ public class OpenSaml4AuthenticationProviderTests {
|
|||||||
|
|
||||||
given(mockResponseToken.getResponse()).willReturn(mockResponse);
|
given(mockResponseToken.getResponse()).willReturn(mockResponse);
|
||||||
|
|
||||||
Converter<ResponseToken, Saml2ResponseValidatorResult> validator = OpenSaml4AuthenticationProvider.createDefaultResponseValidator();
|
Converter<ResponseToken, Saml2ResponseValidatorResult> validator = OpenSaml4AuthenticationProvider
|
||||||
|
.createDefaultResponseValidator();
|
||||||
Saml2ResponseValidatorResult result = validator.convert(mockResponseToken);
|
Saml2ResponseValidatorResult result = validator.convert(mockResponseToken);
|
||||||
|
|
||||||
String expectedErrorMessage = String.format("Invalid status [%s] for SAML response", parentStatusCode.getValue());
|
String expectedErrorMessage = String.format("Invalid status [%s] for SAML response",
|
||||||
assertTrue(result.getErrors().stream().anyMatch(error -> error.getDescription().contains(expectedErrorMessage)));
|
parentStatusCode.getValue());
|
||||||
assertFalse(result.getErrors().stream().anyMatch(error -> error.getDescription().contains(childStatusCode.getValue())));
|
assertThat(
|
||||||
|
result.getErrors().stream().anyMatch((error) -> error.getDescription().contains(expectedErrorMessage)));
|
||||||
|
assertThat(result.getErrors()
|
||||||
|
.stream()
|
||||||
|
.noneMatch((error) -> error.getDescription().contains(childStatusCode.getValue())));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void setsParentAndChildStatusCodeOnResultDescription() {
|
public void authenticateWhenResponseStatusIsNotSuccessThenReturnParentAndChildStatusCode() {
|
||||||
ResponseToken mockResponseToken = mock(ResponseToken.class);
|
ResponseToken mockResponseToken = mock(ResponseToken.class);
|
||||||
Saml2AuthenticationToken mockSamlToken = mock(Saml2AuthenticationToken.class);
|
Saml2AuthenticationToken mockSamlToken = mock(Saml2AuthenticationToken.class);
|
||||||
given(mockResponseToken.getToken()).willReturn(mockSamlToken);
|
given(mockResponseToken.getToken()).willReturn(mockSamlToken);
|
||||||
@ -780,7 +784,8 @@ public class OpenSaml4AuthenticationProviderTests {
|
|||||||
RelyingPartyRegistration mockRelyingPartyRegistration = mock(RelyingPartyRegistration.class);
|
RelyingPartyRegistration mockRelyingPartyRegistration = mock(RelyingPartyRegistration.class);
|
||||||
given(mockSamlToken.getRelyingPartyRegistration()).willReturn(mockRelyingPartyRegistration);
|
given(mockSamlToken.getRelyingPartyRegistration()).willReturn(mockRelyingPartyRegistration);
|
||||||
|
|
||||||
RelyingPartyRegistration.AssertingPartyDetails mockAssertingPartyDetails = mock(RelyingPartyRegistration.AssertingPartyDetails.class);
|
RelyingPartyRegistration.AssertingPartyDetails mockAssertingPartyDetails = mock(
|
||||||
|
RelyingPartyRegistration.AssertingPartyDetails.class);
|
||||||
given(mockRelyingPartyRegistration.getAssertingPartyDetails()).willReturn(mockAssertingPartyDetails);
|
given(mockRelyingPartyRegistration.getAssertingPartyDetails()).willReturn(mockAssertingPartyDetails);
|
||||||
|
|
||||||
Status parentStatus = new StatusBuilder().buildObject();
|
Status parentStatus = new StatusBuilder().buildObject();
|
||||||
@ -799,11 +804,20 @@ public class OpenSaml4AuthenticationProviderTests {
|
|||||||
|
|
||||||
given(mockResponseToken.getResponse()).willReturn(mockResponse);
|
given(mockResponseToken.getResponse()).willReturn(mockResponse);
|
||||||
|
|
||||||
Converter<ResponseToken, Saml2ResponseValidatorResult> validator = OpenSaml4AuthenticationProvider.createDefaultResponseValidator();
|
Converter<ResponseToken, Saml2ResponseValidatorResult> validator = OpenSaml4AuthenticationProvider
|
||||||
|
.createDefaultResponseValidator();
|
||||||
Saml2ResponseValidatorResult result = validator.convert(mockResponseToken);
|
Saml2ResponseValidatorResult result = validator.convert(mockResponseToken);
|
||||||
|
|
||||||
String expectedErrorMessage = String.format("Invalid status [%s] for SAML response", parentStatusCode.getValue() + childStatusCode.getValue());
|
String expectedParentErrorMessage = String.format("Invalid status [%s] for SAML response",
|
||||||
assertTrue(result.getErrors().stream().anyMatch(error -> error.getDescription().contains(expectedErrorMessage)));
|
parentStatusCode.getValue());
|
||||||
|
String expectedChildErrorMessage = String.format("Invalid status [%s] for SAML response",
|
||||||
|
childStatusCode.getValue());
|
||||||
|
assertThat(result.getErrors()
|
||||||
|
.stream()
|
||||||
|
.anyMatch((error) -> error.getDescription().contains(expectedParentErrorMessage)));
|
||||||
|
assertThat(result.getErrors()
|
||||||
|
.stream()
|
||||||
|
.anyMatch((error) -> error.getDescription().contains(expectedChildErrorMessage)));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
Loading…
x
Reference in New Issue
Block a user