Add relyingPartyRegistrationId to AbstractSaml2AuthenticationRequest

Closes gh-11195
This commit is contained in:
Ulrich Grave 2022-05-16 09:17:53 +02:00 committed by Josh Cummings
parent 8e34b4c15e
commit 7f5c31995e
No known key found for this signature in database
GPG Key ID: A306A51F43B8E5A5
10 changed files with 154 additions and 28 deletions

View File

@ -47,7 +47,8 @@ class Saml2PostAuthenticationRequestMixin {
@JsonCreator
Saml2PostAuthenticationRequestMixin(@JsonProperty("samlRequest") String samlRequest,
@JsonProperty("relayState") String relayState,
@JsonProperty("authenticationRequestUri") String authenticationRequestUri) {
@JsonProperty("authenticationRequestUri") String authenticationRequestUri,
@JsonProperty("relyingPartyRegistrationId") String relyingPartyRegistrationId) {
}
}

View File

@ -48,7 +48,8 @@ class Saml2RedirectAuthenticationRequestMixin {
Saml2RedirectAuthenticationRequestMixin(@JsonProperty("samlRequest") String samlRequest,
@JsonProperty("sigAlg") String sigAlg, @JsonProperty("signature") String signature,
@JsonProperty("relayState") String relayState,
@JsonProperty("authenticationRequestUri") String authenticationRequestUri) {
@JsonProperty("authenticationRequestUri") String authenticationRequestUri,
@JsonProperty("relyingPartyRegistrationId") String relyingPartyRegistrationId) {
}
}

View File

@ -20,6 +20,7 @@ import java.io.Serializable;
import java.nio.charset.Charset;
import org.springframework.security.core.SpringSecurityCoreVersion;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
import org.springframework.util.Assert;
@ -46,6 +47,8 @@ public abstract class AbstractSaml2AuthenticationRequest implements Serializable
private final String authenticationRequestUri;
private final String relyingPartyRegistrationId;
/**
* Mandatory constructor for the {@link AbstractSaml2AuthenticationRequest}
* @param samlRequest - the SAMLRequest XML data, SAML encoded, cannot be empty or
@ -53,13 +56,17 @@ public abstract class AbstractSaml2AuthenticationRequest implements Serializable
* @param relayState - RelayState value that accompanies the request, may be null
* @param authenticationRequestUri - The authenticationRequestUri, a URL, where to
* send the XML message, cannot be empty or null
* @param relyingPartyRegistrationId the registration id of the relying party, may be
* null
*/
AbstractSaml2AuthenticationRequest(String samlRequest, String relayState, String authenticationRequestUri) {
AbstractSaml2AuthenticationRequest(String samlRequest, String relayState, String authenticationRequestUri,
String relyingPartyRegistrationId) {
Assert.hasText(samlRequest, "samlRequest cannot be null or empty");
Assert.hasText(authenticationRequestUri, "authenticationRequestUri cannot be null or empty");
this.authenticationRequestUri = authenticationRequestUri;
this.samlRequest = samlRequest;
this.relayState = relayState;
this.relyingPartyRegistrationId = relyingPartyRegistrationId;
}
/**
@ -89,6 +96,16 @@ public abstract class AbstractSaml2AuthenticationRequest implements Serializable
return this.authenticationRequestUri;
}
/**
* The identifier for the {@link RelyingPartyRegistration} associated with this
* request
* @return the {@link RelyingPartyRegistration} id
* @since 5.8
*/
public String getRelyingPartyRegistrationId() {
return this.relyingPartyRegistrationId;
}
/**
* Returns the binding this AuthNRequest will be sent and encoded with. If
* {@link Saml2MessageBinding#REDIRECT} is used, the DEFLATE encoding will be
@ -108,9 +125,24 @@ public abstract class AbstractSaml2AuthenticationRequest implements Serializable
String relayState;
String relyingPartyRegistrationId;
/**
* @deprecated Use {@link #Builder(RelyingPartyRegistration)} instead
*/
@Deprecated
protected Builder() {
}
/**
* Creates a new Builder with relying party registration
* @param registration the registration of the relying party.
* @sine 5.8
*/
protected Builder(RelyingPartyRegistration registration) {
this.relyingPartyRegistrationId = registration.getRegistrationId();
}
/**
* Casting the return as the generic subtype, when returning itself
* @return this object

View File

@ -30,8 +30,9 @@ import org.springframework.security.saml2.provider.service.registration.Saml2Mes
*/
public class Saml2PostAuthenticationRequest extends AbstractSaml2AuthenticationRequest {
Saml2PostAuthenticationRequest(String samlRequest, String relayState, String authenticationRequestUri) {
super(samlRequest, relayState, authenticationRequestUri);
Saml2PostAuthenticationRequest(String samlRequest, String relayState, String authenticationRequestUri,
String relyingPartyRegistrationId) {
super(samlRequest, relayState, authenticationRequestUri, relyingPartyRegistrationId);
}
/**
@ -50,7 +51,7 @@ public class Saml2PostAuthenticationRequest extends AbstractSaml2AuthenticationR
*/
public static Builder withRelyingPartyRegistration(RelyingPartyRegistration registration) {
String location = registration.getAssertingPartyDetails().getSingleSignOnServiceLocation();
return new Builder().authenticationRequestUri(location);
return new Builder(registration).authenticationRequestUri(location);
}
/**
@ -58,7 +59,8 @@ public class Saml2PostAuthenticationRequest extends AbstractSaml2AuthenticationR
*/
public static final class Builder extends AbstractSaml2AuthenticationRequest.Builder<Builder> {
private Builder() {
private Builder(RelyingPartyRegistration registration) {
super(registration);
}
/**
@ -66,7 +68,8 @@ public class Saml2PostAuthenticationRequest extends AbstractSaml2AuthenticationR
* @return an immutable {@link Saml2PostAuthenticationRequest} object.
*/
public Saml2PostAuthenticationRequest build() {
return new Saml2PostAuthenticationRequest(this.samlRequest, this.relayState, this.authenticationRequestUri);
return new Saml2PostAuthenticationRequest(this.samlRequest, this.relayState, this.authenticationRequestUri,
this.relyingPartyRegistrationId);
}
}

View File

@ -35,8 +35,8 @@ public final class Saml2RedirectAuthenticationRequest extends AbstractSaml2Authe
private final String signature;
private Saml2RedirectAuthenticationRequest(String samlRequest, String sigAlg, String signature, String relayState,
String authenticationRequestUri) {
super(samlRequest, relayState, authenticationRequestUri);
String authenticationRequestUri, String relyingPartyRegistrationId) {
super(samlRequest, relayState, authenticationRequestUri, relyingPartyRegistrationId);
this.sigAlg = sigAlg;
this.signature = signature;
}
@ -74,7 +74,7 @@ public final class Saml2RedirectAuthenticationRequest extends AbstractSaml2Authe
*/
public static Builder withRelyingPartyRegistration(RelyingPartyRegistration registration) {
String location = registration.getAssertingPartyDetails().getSingleSignOnServiceLocation();
return new Builder().authenticationRequestUri(location);
return new Builder(registration).authenticationRequestUri(location);
}
/**
@ -86,7 +86,8 @@ public final class Saml2RedirectAuthenticationRequest extends AbstractSaml2Authe
private String signature;
private Builder() {
private Builder(RelyingPartyRegistration registration) {
super(registration);
}
/**
@ -115,7 +116,7 @@ public final class Saml2RedirectAuthenticationRequest extends AbstractSaml2Authe
*/
public Saml2RedirectAuthenticationRequest build() {
return new Saml2RedirectAuthenticationRequest(this.samlRequest, this.sigAlg, this.signature,
this.relayState, this.authenticationRequestUri);
this.relayState, this.authenticationRequestUri, this.relyingPartyRegistrationId);
}
}

View File

@ -26,7 +26,6 @@ import jakarta.servlet.http.HttpServletRequest;
import org.apache.commons.codec.CodecPolicy;
import org.apache.commons.codec.binary.Base64;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpMethod;
import org.springframework.security.saml2.core.Saml2Error;
import org.springframework.security.saml2.core.Saml2ErrorCodes;
@ -50,25 +49,29 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
private static Base64 BASE64 = new Base64(0, new byte[] { '\n' }, false, CodecPolicy.STRICT);
private final Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver;
private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver;
private Function<HttpServletRequest, AbstractSaml2AuthenticationRequest> loader;
/**
* Constructs a {@link Saml2AuthenticationTokenConverter} given a strategy for
* resolving {@link RelyingPartyRegistration}s
* @param relyingPartyRegistrationResolver the strategy for resolving
* {@link RelyingPartyRegistration}s
*/
public Saml2AuthenticationTokenConverter(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) {
Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null");
this.relyingPartyRegistrationResolver = adaptToConverter(relyingPartyRegistrationResolver);
this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
this.loader = new HttpSessionSaml2AuthenticationRequestRepository()::loadAuthenticationRequest;
}
private static Converter<HttpServletRequest, RelyingPartyRegistration> adaptToConverter(
RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) {
Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null");
return (request) -> relyingPartyRegistrationResolver.resolve(request, null);
}
@Override
public Saml2AuthenticationToken convert(HttpServletRequest request) {
RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationResolver.convert(request);
AbstractSaml2AuthenticationRequest authenticationRequest = loadAuthenticationRequest(request);
String relyingPartyRegistrationId = (authenticationRequest != null)
? authenticationRequest.getRelyingPartyRegistrationId() : null;
RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationResolver.resolve(request,
relyingPartyRegistrationId);
if (relyingPartyRegistration == null) {
return null;
}
@ -78,7 +81,6 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
}
byte[] b = samlDecode(saml2Response);
saml2Response = inflateIfRequired(request, b);
AbstractSaml2AuthenticationRequest authenticationRequest = loadAuthenticationRequest(request);
return new Saml2AuthenticationToken(relyingPartyRegistration, saml2Response, authenticationRequest);
}

View File

@ -56,6 +56,23 @@ class Saml2PostAuthenticationRequestMixinTests {
assertThat(authRequest.getRelayState()).isEqualTo(TestSaml2JsonPayloads.RELAY_STATE);
assertThat(authRequest.getAuthenticationRequestUri())
.isEqualTo(TestSaml2JsonPayloads.AUTHENTICATION_REQUEST_URI);
assertThat(authRequest.getRelyingPartyRegistrationId())
.isEqualTo(TestSaml2JsonPayloads.RELYINGPARTY_REGISTRATION_ID);
}
@Test
void shouldDeserializeWithNoRegistrationId() throws Exception {
String json = TestSaml2JsonPayloads.DEFAULT_POST_AUTH_REQUEST_JSON.replace(
"\"relyingPartyRegistrationId\": \"" + TestSaml2JsonPayloads.RELYINGPARTY_REGISTRATION_ID + "\",", "");
Saml2PostAuthenticationRequest authRequest = this.mapper.readValue(json, Saml2PostAuthenticationRequest.class);
assertThat(authRequest).isNotNull();
assertThat(authRequest.getSamlRequest()).isEqualTo(TestSaml2JsonPayloads.SAML_REQUEST);
assertThat(authRequest.getRelayState()).isEqualTo(TestSaml2JsonPayloads.RELAY_STATE);
assertThat(authRequest.getAuthenticationRequestUri())
.isEqualTo(TestSaml2JsonPayloads.AUTHENTICATION_REQUEST_URI);
assertThat(authRequest.getRelyingPartyRegistrationId()).isNull();
}
}

View File

@ -59,6 +59,26 @@ class Saml2RedirectAuthenticationRequestMixinTests {
.isEqualTo(TestSaml2JsonPayloads.AUTHENTICATION_REQUEST_URI);
assertThat(authRequest.getSigAlg()).isEqualTo(TestSaml2JsonPayloads.SIG_ALG);
assertThat(authRequest.getSignature()).isEqualTo(TestSaml2JsonPayloads.SIGNATURE);
assertThat(authRequest.getRelyingPartyRegistrationId())
.isEqualTo(TestSaml2JsonPayloads.RELYINGPARTY_REGISTRATION_ID);
}
@Test
void shouldDeserializeWithNoRegistrationId() throws Exception {
String json = TestSaml2JsonPayloads.DEFAULT_REDIRECT_AUTH_REQUEST_JSON.replace(
"\"relyingPartyRegistrationId\": \"" + TestSaml2JsonPayloads.RELYINGPARTY_REGISTRATION_ID + "\",", "");
Saml2RedirectAuthenticationRequest authRequest = this.mapper.readValue(json,
Saml2RedirectAuthenticationRequest.class);
assertThat(authRequest).isNotNull();
assertThat(authRequest.getSamlRequest()).isEqualTo(TestSaml2JsonPayloads.SAML_REQUEST);
assertThat(authRequest.getRelayState()).isEqualTo(TestSaml2JsonPayloads.RELAY_STATE);
assertThat(authRequest.getAuthenticationRequestUri())
.isEqualTo(TestSaml2JsonPayloads.AUTHENTICATION_REQUEST_URI);
assertThat(authRequest.getSigAlg()).isEqualTo(TestSaml2JsonPayloads.SIG_ALG);
assertThat(authRequest.getSignature()).isEqualTo(TestSaml2JsonPayloads.SIGNATURE);
assertThat(authRequest.getRelyingPartyRegistrationId()).isNull();
}
}

View File

@ -94,6 +94,7 @@ final class TestSaml2JsonPayloads {
static final String SAML_REQUEST = "samlRequestValue";
static final String RELAY_STATE = "relayStateValue";
static final String AUTHENTICATION_REQUEST_URI = "authenticationRequestUriValue";
static final String RELYINGPARTY_REGISTRATION_ID = "registrationIdValue";
static final String SIG_ALG = "sigAlgValue";
static final String SIGNATURE = "signatureValue";
@ -103,6 +104,7 @@ final class TestSaml2JsonPayloads {
+ " \"samlRequest\": \"" + SAML_REQUEST + "\","
+ " \"relayState\": \"" + RELAY_STATE + "\","
+ " \"authenticationRequestUri\": \"" + AUTHENTICATION_REQUEST_URI + "\","
+ " \"relyingPartyRegistrationId\": \"" + RELYINGPARTY_REGISTRATION_ID + "\","
+ " \"sigAlg\": \"" + SIG_ALG + "\","
+ " \"signature\": \"" + SIGNATURE + "\""
+ "}";
@ -113,6 +115,7 @@ final class TestSaml2JsonPayloads {
+ " \"@class\": \"org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest\","
+ " \"samlRequest\": \"" + SAML_REQUEST + "\","
+ " \"relayState\": \"" + RELAY_STATE + "\","
+ " \"relyingPartyRegistrationId\": \"" + RELYINGPARTY_REGISTRATION_ID + "\","
+ " \"authenticationRequestUri\": \"" + AUTHENTICATION_REQUEST_URI + "\""
+ "}";
// @formatter:on
@ -120,7 +123,6 @@ final class TestSaml2JsonPayloads {
static final String ID = "idValue";
static final String LOCATION = "locationValue";
static final String BINDNG = "REDIRECT";
static final String RELYINGPARTY_REGISTRATION_ID = "registrationIdValue";
static final String ADDITIONAL_PARAM = "additionalParamValue";
// @formatter:off
@ -140,14 +142,17 @@ final class TestSaml2JsonPayloads {
// @formatter:on
static Saml2PostAuthenticationRequest createDefaultSaml2PostAuthenticationRequest() {
return Saml2PostAuthenticationRequest.withRelyingPartyRegistration(TestRelyingPartyRegistrations.full()
.assertingPartyDetails((party) -> party.singleSignOnServiceLocation(AUTHENTICATION_REQUEST_URI))
.build()).samlRequest(SAML_REQUEST).relayState(RELAY_STATE).build();
return Saml2PostAuthenticationRequest.withRelyingPartyRegistration(
TestRelyingPartyRegistrations.full().registrationId(RELYINGPARTY_REGISTRATION_ID)
.assertingPartyDetails((party) -> party.singleSignOnServiceLocation(AUTHENTICATION_REQUEST_URI))
.build())
.samlRequest(SAML_REQUEST).relayState(RELAY_STATE).build();
}
static Saml2RedirectAuthenticationRequest createDefaultSaml2RedirectAuthenticationRequest() {
return Saml2RedirectAuthenticationRequest
.withRelyingPartyRegistration(TestRelyingPartyRegistrations.full()
.registrationId(RELYINGPARTY_REGISTRATION_ID)
.assertingPartyDetails((party) -> party.singleSignOnServiceLocation(AUTHENTICATION_REQUEST_URI))
.build())
.samlRequest(SAML_REQUEST).relayState(RELAY_STATE).sigAlg(SIG_ALG).signature(SIGNATURE).build();

View File

@ -42,8 +42,11 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
@ExtendWith(MockitoExtension.class)
public class Saml2AuthenticationTokenConverterTests {
@ -69,6 +72,21 @@ public class Saml2AuthenticationTokenConverterTests {
.isEqualTo(this.relyingPartyRegistration.getRegistrationId());
}
@Test
public void convertWhenSamlResponseWithRelyingPartyRegistrationResolver(
@Mock RelyingPartyRegistrationResolver resolver) {
Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(resolver);
given(resolver.resolve(any(HttpServletRequest.class), any())).willReturn(this.relyingPartyRegistration);
MockHttpServletRequest request = new MockHttpServletRequest();
request.setParameter(Saml2ParameterNames.SAML_RESPONSE,
Saml2Utils.samlEncodeNotRfc2045("response".getBytes(StandardCharsets.UTF_8)));
Saml2AuthenticationToken token = converter.convert(request);
assertThat(token.getSaml2Response()).isEqualTo("response");
assertThat(token.getRelyingPartyRegistration().getRegistrationId())
.isEqualTo(this.relyingPartyRegistration.getRegistrationId());
verify(resolver).resolve(any(), isNull());
}
@Test
public void convertWhenSamlResponseInvalidBase64ThenSaml2AuthenticationException() {
Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(
@ -157,6 +175,8 @@ public class Saml2AuthenticationTokenConverterTests {
Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository = mock(
Saml2AuthenticationRequestRepository.class);
AbstractSaml2AuthenticationRequest authenticationRequest = mock(AbstractSaml2AuthenticationRequest.class);
given(authenticationRequest.getRelyingPartyRegistrationId())
.willReturn(this.relyingPartyRegistration.getRegistrationId());
Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(
this.relyingPartyRegistrationResolver);
converter.setAuthenticationRequestRepository(authenticationRequestRepository);
@ -174,6 +194,30 @@ public class Saml2AuthenticationTokenConverterTests {
assertThat(token.getAuthenticationRequest()).isEqualTo(authenticationRequest);
}
@Test
public void convertWhenSavedAuthenticationRequestThenTokenWithRelyingPartyRegistrationResolver(
@Mock RelyingPartyRegistrationResolver resolver) {
Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository = mock(
Saml2AuthenticationRequestRepository.class);
AbstractSaml2AuthenticationRequest authenticationRequest = mock(AbstractSaml2AuthenticationRequest.class);
given(authenticationRequest.getRelyingPartyRegistrationId())
.willReturn(this.relyingPartyRegistration.getRegistrationId());
Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(resolver);
converter.setAuthenticationRequestRepository(authenticationRequestRepository);
given(resolver.resolve(any(HttpServletRequest.class), any())).willReturn(this.relyingPartyRegistration);
given(authenticationRequestRepository.loadAuthenticationRequest(any(HttpServletRequest.class)))
.willReturn(authenticationRequest);
MockHttpServletRequest request = new MockHttpServletRequest();
request.setParameter(Saml2ParameterNames.SAML_RESPONSE,
Saml2Utils.samlEncode("response".getBytes(StandardCharsets.UTF_8)));
Saml2AuthenticationToken token = converter.convert(request);
assertThat(token.getSaml2Response()).isEqualTo("response");
assertThat(token.getRelyingPartyRegistration().getRegistrationId())
.isEqualTo(this.relyingPartyRegistration.getRegistrationId());
assertThat(token.getAuthenticationRequest()).isEqualTo(authenticationRequest);
verify(resolver).resolve(any(), eq(this.relyingPartyRegistration.getRegistrationId()));
}
@Test
public void constructorWhenResolverIsNullThenIllegalArgument() {
assertThatIllegalArgumentException().isThrownBy(() -> new Saml2AuthenticationTokenConverter(null));