Add ID to Saml2 Post and Redirect Requests

Closes gh-11468
This commit is contained in:
Scott Shidlovsky 2022-07-12 11:02:18 -04:00 committed by Josh Cummings
parent 15f525c614
commit 947445fcc5
8 changed files with 69 additions and 16 deletions

View File

@ -48,7 +48,8 @@ class Saml2PostAuthenticationRequestMixin {
Saml2PostAuthenticationRequestMixin(@JsonProperty("samlRequest") String samlRequest, Saml2PostAuthenticationRequestMixin(@JsonProperty("samlRequest") String samlRequest,
@JsonProperty("relayState") String relayState, @JsonProperty("relayState") String relayState,
@JsonProperty("authenticationRequestUri") String authenticationRequestUri, @JsonProperty("authenticationRequestUri") String authenticationRequestUri,
@JsonProperty("relyingPartyRegistrationId") String relyingPartyRegistrationId) { @JsonProperty("relyingPartyRegistrationId") String relyingPartyRegistrationId,
@JsonProperty("id") String id) {
} }
} }

View File

@ -49,7 +49,8 @@ class Saml2RedirectAuthenticationRequestMixin {
@JsonProperty("sigAlg") String sigAlg, @JsonProperty("signature") String signature, @JsonProperty("sigAlg") String sigAlg, @JsonProperty("signature") String signature,
@JsonProperty("relayState") String relayState, @JsonProperty("relayState") String relayState,
@JsonProperty("authenticationRequestUri") String authenticationRequestUri, @JsonProperty("authenticationRequestUri") String authenticationRequestUri,
@JsonProperty("relyingPartyRegistrationId") String relyingPartyRegistrationId) { @JsonProperty("relyingPartyRegistrationId") String relyingPartyRegistrationId,
@JsonProperty("id") String id) {
} }
} }

View File

@ -49,6 +49,8 @@ public abstract class AbstractSaml2AuthenticationRequest implements Serializable
private final String relyingPartyRegistrationId; private final String relyingPartyRegistrationId;
private final String id;
/** /**
* Mandatory constructor for the {@link AbstractSaml2AuthenticationRequest} * Mandatory constructor for the {@link AbstractSaml2AuthenticationRequest}
* @param samlRequest - the SAMLRequest XML data, SAML encoded, cannot be empty or * @param samlRequest - the SAMLRequest XML data, SAML encoded, cannot be empty or
@ -58,15 +60,18 @@ public abstract class AbstractSaml2AuthenticationRequest implements Serializable
* send the XML message, cannot be empty or null * send the XML message, cannot be empty or null
* @param relyingPartyRegistrationId the registration id of the relying party, may be * @param relyingPartyRegistrationId the registration id of the relying party, may be
* null * null
* @param id This is the unique id used in the {@link #samlRequest}, cannot be empty
* or null
*/ */
AbstractSaml2AuthenticationRequest(String samlRequest, String relayState, String authenticationRequestUri, AbstractSaml2AuthenticationRequest(String samlRequest, String relayState, String authenticationRequestUri,
String relyingPartyRegistrationId) { String relyingPartyRegistrationId, String id) {
Assert.hasText(samlRequest, "samlRequest cannot be null or empty"); Assert.hasText(samlRequest, "samlRequest cannot be null or empty");
Assert.hasText(authenticationRequestUri, "authenticationRequestUri cannot be null or empty"); Assert.hasText(authenticationRequestUri, "authenticationRequestUri cannot be null or empty");
this.authenticationRequestUri = authenticationRequestUri; this.authenticationRequestUri = authenticationRequestUri;
this.samlRequest = samlRequest; this.samlRequest = samlRequest;
this.relayState = relayState; this.relayState = relayState;
this.relyingPartyRegistrationId = relyingPartyRegistrationId; this.relyingPartyRegistrationId = relyingPartyRegistrationId;
this.id = id;
} }
/** /**
@ -106,6 +111,15 @@ public abstract class AbstractSaml2AuthenticationRequest implements Serializable
return this.relyingPartyRegistrationId; return this.relyingPartyRegistrationId;
} }
/**
* The unique identifier for this Authentication Request
* @return the Authentication Request identifier
* @since 5.8
*/
public String getId() {
return this.id;
}
/** /**
* Returns the binding this AuthNRequest will be sent and encoded with. If * Returns the binding this AuthNRequest will be sent and encoded with. If
* {@link Saml2MessageBinding#REDIRECT} is used, the DEFLATE encoding will be * {@link Saml2MessageBinding#REDIRECT} is used, the DEFLATE encoding will be
@ -127,6 +141,8 @@ public abstract class AbstractSaml2AuthenticationRequest implements Serializable
String relyingPartyRegistrationId; String relyingPartyRegistrationId;
String id;
/** /**
* @deprecated Use {@link #Builder(RelyingPartyRegistration)} instead * @deprecated Use {@link #Builder(RelyingPartyRegistration)} instead
*/ */
@ -184,6 +200,19 @@ public abstract class AbstractSaml2AuthenticationRequest implements Serializable
return _this(); return _this();
} }
/**
* This is the unique id used in the {@link #samlRequest}
* @param id the SAML2 request id
* @return the {@link AbstractSaml2AuthenticationRequest.Builder} for further
* configurations
* @since 5.8
*/
public T id(String id) {
Assert.notNull(id, "id cannot be null");
this.id = id;
return _this();
}
} }
} }

View File

@ -31,8 +31,8 @@ import org.springframework.security.saml2.provider.service.registration.Saml2Mes
public class Saml2PostAuthenticationRequest extends AbstractSaml2AuthenticationRequest { public class Saml2PostAuthenticationRequest extends AbstractSaml2AuthenticationRequest {
Saml2PostAuthenticationRequest(String samlRequest, String relayState, String authenticationRequestUri, Saml2PostAuthenticationRequest(String samlRequest, String relayState, String authenticationRequestUri,
String relyingPartyRegistrationId) { String relyingPartyRegistrationId, String id) {
super(samlRequest, relayState, authenticationRequestUri, relyingPartyRegistrationId); super(samlRequest, relayState, authenticationRequestUri, relyingPartyRegistrationId, id);
} }
/** /**
@ -69,7 +69,7 @@ public class Saml2PostAuthenticationRequest extends AbstractSaml2AuthenticationR
*/ */
public Saml2PostAuthenticationRequest build() { public Saml2PostAuthenticationRequest build() {
return new Saml2PostAuthenticationRequest(this.samlRequest, this.relayState, this.authenticationRequestUri, return new Saml2PostAuthenticationRequest(this.samlRequest, this.relayState, this.authenticationRequestUri,
this.relyingPartyRegistrationId); this.relyingPartyRegistrationId, this.id);
} }
} }

View File

@ -35,8 +35,8 @@ public final class Saml2RedirectAuthenticationRequest extends AbstractSaml2Authe
private final String signature; private final String signature;
private Saml2RedirectAuthenticationRequest(String samlRequest, String sigAlg, String signature, String relayState, private Saml2RedirectAuthenticationRequest(String samlRequest, String sigAlg, String signature, String relayState,
String authenticationRequestUri, String relyingPartyRegistrationId) { String authenticationRequestUri, String relyingPartyRegistrationId, String id) {
super(samlRequest, relayState, authenticationRequestUri, relyingPartyRegistrationId); super(samlRequest, relayState, authenticationRequestUri, relyingPartyRegistrationId, id);
this.sigAlg = sigAlg; this.sigAlg = sigAlg;
this.signature = signature; this.signature = signature;
} }
@ -116,7 +116,7 @@ public final class Saml2RedirectAuthenticationRequest extends AbstractSaml2Authe
*/ */
public Saml2RedirectAuthenticationRequest build() { public Saml2RedirectAuthenticationRequest build() {
return new Saml2RedirectAuthenticationRequest(this.samlRequest, this.sigAlg, this.signature, return new Saml2RedirectAuthenticationRequest(this.samlRequest, this.sigAlg, this.signature,
this.relayState, this.authenticationRequestUri, this.relyingPartyRegistrationId); this.relayState, this.authenticationRequestUri, this.relyingPartyRegistrationId, this.id);
} }
} }

View File

@ -142,13 +142,14 @@ class OpenSamlAuthenticationRequestResolver {
String xml = serialize(authnRequest); String xml = serialize(authnRequest);
String encoded = Saml2Utils.samlEncode(xml.getBytes(StandardCharsets.UTF_8)); String encoded = Saml2Utils.samlEncode(xml.getBytes(StandardCharsets.UTF_8));
return (T) Saml2PostAuthenticationRequest.withRelyingPartyRegistration(registration).samlRequest(encoded) return (T) Saml2PostAuthenticationRequest.withRelyingPartyRegistration(registration).samlRequest(encoded)
.relayState(relayState).build(); .relayState(relayState).id(authnRequest.getID()).build();
} }
else { else {
String xml = serialize(authnRequest); String xml = serialize(authnRequest);
String deflatedAndEncoded = Saml2Utils.samlEncode(Saml2Utils.samlDeflate(xml)); String deflatedAndEncoded = Saml2Utils.samlEncode(Saml2Utils.samlDeflate(xml));
Saml2RedirectAuthenticationRequest.Builder builder = Saml2RedirectAuthenticationRequest Saml2RedirectAuthenticationRequest.Builder builder = Saml2RedirectAuthenticationRequest
.withRelyingPartyRegistration(registration).samlRequest(deflatedAndEncoded).relayState(relayState); .withRelyingPartyRegistration(registration).samlRequest(deflatedAndEncoded).relayState(relayState)
.id(authnRequest.getID());
if (registration.getAssertingPartyDetails().getWantAuthnRequestsSigned()) { if (registration.getAssertingPartyDetails().getWantAuthnRequestsSigned()) {
Map<String, String> parameters = OpenSamlSigningUtils.sign(registration) Map<String, String> parameters = OpenSamlSigningUtils.sign(registration)
.param(Saml2ParameterNames.SAML_REQUEST, deflatedAndEncoded) .param(Saml2ParameterNames.SAML_REQUEST, deflatedAndEncoded)

View File

@ -58,6 +58,7 @@ class Saml2PostAuthenticationRequestMixinTests {
.isEqualTo(TestSaml2JsonPayloads.AUTHENTICATION_REQUEST_URI); .isEqualTo(TestSaml2JsonPayloads.AUTHENTICATION_REQUEST_URI);
assertThat(authRequest.getRelyingPartyRegistrationId()) assertThat(authRequest.getRelyingPartyRegistrationId())
.isEqualTo(TestSaml2JsonPayloads.RELYINGPARTY_REGISTRATION_ID); .isEqualTo(TestSaml2JsonPayloads.RELYINGPARTY_REGISTRATION_ID);
assertThat(authRequest.getId()).isEqualTo(TestSaml2JsonPayloads.ID);
} }
@Test @Test
@ -73,6 +74,24 @@ class Saml2PostAuthenticationRequestMixinTests {
assertThat(authRequest.getAuthenticationRequestUri()) assertThat(authRequest.getAuthenticationRequestUri())
.isEqualTo(TestSaml2JsonPayloads.AUTHENTICATION_REQUEST_URI); .isEqualTo(TestSaml2JsonPayloads.AUTHENTICATION_REQUEST_URI);
assertThat(authRequest.getRelyingPartyRegistrationId()).isNull(); assertThat(authRequest.getRelyingPartyRegistrationId()).isNull();
assertThat(authRequest.getId()).isEqualTo(TestSaml2JsonPayloads.ID);
}
@Test
void shouldDeserializeWithNoId() throws Exception {
String json = TestSaml2JsonPayloads.DEFAULT_POST_AUTH_REQUEST_JSON
.replace(", \"id\": \"" + TestSaml2JsonPayloads.ID + "\"", "");
Saml2PostAuthenticationRequest authRequest = this.mapper.readValue(json, Saml2PostAuthenticationRequest.class);
assertThat(authRequest).isNotNull();
assertThat(authRequest.getSamlRequest()).isEqualTo(TestSaml2JsonPayloads.SAML_REQUEST);
assertThat(authRequest.getRelayState()).isEqualTo(TestSaml2JsonPayloads.RELAY_STATE);
assertThat(authRequest.getAuthenticationRequestUri())
.isEqualTo(TestSaml2JsonPayloads.AUTHENTICATION_REQUEST_URI);
assertThat(authRequest.getRelyingPartyRegistrationId())
.isEqualTo(TestSaml2JsonPayloads.RELYINGPARTY_REGISTRATION_ID);
assertThat(authRequest.getId()).isNull();
} }
} }

View File

@ -97,6 +97,7 @@ final class TestSaml2JsonPayloads {
static final String RELYINGPARTY_REGISTRATION_ID = "registrationIdValue"; static final String RELYINGPARTY_REGISTRATION_ID = "registrationIdValue";
static final String SIG_ALG = "sigAlgValue"; static final String SIG_ALG = "sigAlgValue";
static final String SIGNATURE = "signatureValue"; static final String SIGNATURE = "signatureValue";
static final String ID = "idValue";
// @formatter:off // @formatter:off
static final String DEFAULT_REDIRECT_AUTH_REQUEST_JSON = "{" static final String DEFAULT_REDIRECT_AUTH_REQUEST_JSON = "{"
@ -106,7 +107,8 @@ final class TestSaml2JsonPayloads {
+ " \"authenticationRequestUri\": \"" + AUTHENTICATION_REQUEST_URI + "\"," + " \"authenticationRequestUri\": \"" + AUTHENTICATION_REQUEST_URI + "\","
+ " \"relyingPartyRegistrationId\": \"" + RELYINGPARTY_REGISTRATION_ID + "\"," + " \"relyingPartyRegistrationId\": \"" + RELYINGPARTY_REGISTRATION_ID + "\","
+ " \"sigAlg\": \"" + SIG_ALG + "\"," + " \"sigAlg\": \"" + SIG_ALG + "\","
+ " \"signature\": \"" + SIGNATURE + "\"" + " \"signature\": \"" + SIGNATURE + "\","
+ " \"id\": \"" + ID + "\""
+ "}"; + "}";
// @formatter:on // @formatter:on
@ -116,11 +118,11 @@ final class TestSaml2JsonPayloads {
+ " \"samlRequest\": \"" + SAML_REQUEST + "\"," + " \"samlRequest\": \"" + SAML_REQUEST + "\","
+ " \"relayState\": \"" + RELAY_STATE + "\"," + " \"relayState\": \"" + RELAY_STATE + "\","
+ " \"relyingPartyRegistrationId\": \"" + RELYINGPARTY_REGISTRATION_ID + "\"," + " \"relyingPartyRegistrationId\": \"" + RELYINGPARTY_REGISTRATION_ID + "\","
+ " \"authenticationRequestUri\": \"" + AUTHENTICATION_REQUEST_URI + "\"" + " \"authenticationRequestUri\": \"" + AUTHENTICATION_REQUEST_URI + "\","
+ " \"id\": \"" + ID + "\""
+ "}"; + "}";
// @formatter:on // @formatter:on
static final String ID = "idValue";
static final String LOCATION = "locationValue"; static final String LOCATION = "locationValue";
static final String BINDNG = "REDIRECT"; static final String BINDNG = "REDIRECT";
static final String ADDITIONAL_PARAM = "additionalParamValue"; static final String ADDITIONAL_PARAM = "additionalParamValue";
@ -146,7 +148,7 @@ final class TestSaml2JsonPayloads {
TestRelyingPartyRegistrations.full().registrationId(RELYINGPARTY_REGISTRATION_ID) TestRelyingPartyRegistrations.full().registrationId(RELYINGPARTY_REGISTRATION_ID)
.assertingPartyDetails((party) -> party.singleSignOnServiceLocation(AUTHENTICATION_REQUEST_URI)) .assertingPartyDetails((party) -> party.singleSignOnServiceLocation(AUTHENTICATION_REQUEST_URI))
.build()) .build())
.samlRequest(SAML_REQUEST).relayState(RELAY_STATE).build(); .samlRequest(SAML_REQUEST).relayState(RELAY_STATE).id(ID).build();
} }
static Saml2RedirectAuthenticationRequest createDefaultSaml2RedirectAuthenticationRequest() { static Saml2RedirectAuthenticationRequest createDefaultSaml2RedirectAuthenticationRequest() {
@ -155,7 +157,7 @@ final class TestSaml2JsonPayloads {
.registrationId(RELYINGPARTY_REGISTRATION_ID) .registrationId(RELYINGPARTY_REGISTRATION_ID)
.assertingPartyDetails((party) -> party.singleSignOnServiceLocation(AUTHENTICATION_REQUEST_URI)) .assertingPartyDetails((party) -> party.singleSignOnServiceLocation(AUTHENTICATION_REQUEST_URI))
.build()) .build())
.samlRequest(SAML_REQUEST).relayState(RELAY_STATE).sigAlg(SIG_ALG).signature(SIGNATURE).build(); .samlRequest(SAML_REQUEST).relayState(RELAY_STATE).sigAlg(SIG_ALG).signature(SIGNATURE).id(ID).build();
} }
static Saml2LogoutRequest createDefaultSaml2LogoutRequest() { static Saml2LogoutRequest createDefaultSaml2LogoutRequest() {