diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolver.java index 0f51590571..b49aacb287 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolver.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolver.java @@ -176,10 +176,12 @@ class OpenSamlAuthenticationRequestResolver { .id(authnRequest.getID()); if (registration.getAssertingPartyDetails().getWantAuthnRequestsSigned() || registration.isAuthnRequestsSigned()) { - Map parameters = OpenSamlSigningUtils.sign(registration) - .param(Saml2ParameterNames.SAML_REQUEST, deflatedAndEncoded) - .param(Saml2ParameterNames.RELAY_STATE, relayState) - .parameters(); + OpenSamlSigningUtils.QueryParametersPartial parametersPartial = OpenSamlSigningUtils.sign(registration) + .param(Saml2ParameterNames.SAML_REQUEST, deflatedAndEncoded); + if (relayState != null) { + parametersPartial = parametersPartial.param(Saml2ParameterNames.RELAY_STATE, relayState); + } + Map parameters = parametersPartial.parameters(); builder.sigAlg(parameters.get(Saml2ParameterNames.SIG_ALG)) .signature(parameters.get(Saml2ParameterNames.SIGNATURE)); } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolverTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolverTests.java index 6f15358377..b4088bac63 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolverTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolverTests.java @@ -23,10 +23,13 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Answers; +import org.mockito.MockedStatic; import org.opensaml.xmlsec.signature.support.SignatureConstants; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.security.saml2.Saml2Exception; +import org.springframework.security.saml2.core.Saml2ParameterNames; import org.springframework.security.saml2.core.Saml2X509Credential; import org.springframework.security.saml2.core.TestSaml2X509Credentials; import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest; @@ -39,6 +42,12 @@ import org.springframework.security.saml2.provider.service.web.RelyingPartyRegis import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; /** * Tests for {@link OpenSamlAuthenticationRequestResolver} @@ -198,6 +207,58 @@ public class OpenSamlAuthenticationRequestResolverTests { assertThat(result.getId()).isNotEmpty(); } + @Test + public void resolveAuthenticationRequestWhenSignedAndRelayStateIsNullThenSignsWithoutRelayState() { + try (MockedStatic openSamlSigningUtilsMockedStatic = mockStatic( + OpenSamlSigningUtils.class, Answers.CALLS_REAL_METHODS)) { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setPathInfo("/saml2/authenticate/registration-id"); + RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder + .assertingPartyDetails((party) -> party.wantAuthnRequestsSigned(true)) + .build(); + OpenSamlSigningUtils.QueryParametersPartial queryParametersPartialSpy = spy( + new OpenSamlSigningUtils.QueryParametersPartial(registration)); + openSamlSigningUtilsMockedStatic.when(() -> OpenSamlSigningUtils.sign(any())) + .thenReturn(queryParametersPartialSpy); + OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration); + resolver.setRelayStateResolver((source) -> null); + Saml2RedirectAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> { + }); + assertThat(result.getSamlRequest()).isNotEmpty(); + assertThat(result.getRelayState()).isNull(); + assertThat(result.getSigAlg()).isNotNull(); + assertThat(result.getSignature()).isNotNull(); + assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT); + verify(queryParametersPartialSpy, never()).param(eq(Saml2ParameterNames.RELAY_STATE), any()); + } + } + + @Test + public void resolveAuthenticationRequestWhenSignedAndRelayStateIsEmptyThenSignsWithEmptyRelayState() { + try (MockedStatic openSamlSigningUtilsMockedStatic = mockStatic( + OpenSamlSigningUtils.class, Answers.CALLS_REAL_METHODS)) { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setPathInfo("/saml2/authenticate/registration-id"); + RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder + .assertingPartyDetails((party) -> party.wantAuthnRequestsSigned(true)) + .build(); + OpenSamlSigningUtils.QueryParametersPartial queryParametersPartialSpy = spy( + new OpenSamlSigningUtils.QueryParametersPartial(registration)); + openSamlSigningUtilsMockedStatic.when(() -> OpenSamlSigningUtils.sign(any())) + .thenReturn(queryParametersPartialSpy); + OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration); + resolver.setRelayStateResolver((source) -> ""); + Saml2RedirectAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> { + }); + assertThat(result.getSamlRequest()).isNotEmpty(); + assertThat(result.getRelayState()).isEmpty(); + assertThat(result.getSigAlg()).isNotNull(); + assertThat(result.getSignature()).isNotNull(); + assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT); + verify(queryParametersPartialSpy).param(eq(Saml2ParameterNames.RELAY_STATE), eq("")); + } + } + private OpenSamlAuthenticationRequestResolver authenticationRequestResolver(RelyingPartyRegistration registration) { return new OpenSamlAuthenticationRequestResolver((request, id) -> registration); }