From 70ad3bf749b6be52a1f71ac29204eaf0333ac394 Mon Sep 17 00:00:00 2001 From: Marcus Da Coregio Date: Thu, 19 Oct 2023 09:58:47 -0300 Subject: [PATCH] relay_state should not be included in signing calculation when it is null Closes gh-13913 --- ...OpenSamlAuthenticationRequestResolver.java | 10 +-- ...amlAuthenticationRequestResolverTests.java | 65 ++++++++++++++++++- 2 files changed, 69 insertions(+), 6 deletions(-) diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolver.java index 6e2d1d5395..a66b4dd805 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolver.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolver.java @@ -167,10 +167,12 @@ class OpenSamlAuthenticationRequestResolver { .samlRequest(deflatedAndEncoded) .relayState(relayState); if (registration.getAssertingPartyDetails().getWantAuthnRequestsSigned()) { - Map parameters = OpenSamlSigningUtils.sign(registration) - .param(Saml2ParameterNames.SAML_REQUEST, deflatedAndEncoded) - .param(Saml2ParameterNames.RELAY_STATE, relayState) - .parameters(); + OpenSamlSigningUtils.QueryParametersPartial parametersPartial = OpenSamlSigningUtils.sign(registration) + .param(Saml2ParameterNames.SAML_REQUEST, deflatedAndEncoded); + if (relayState != null) { + parametersPartial = parametersPartial.param(Saml2ParameterNames.RELAY_STATE, relayState); + } + Map parameters = parametersPartial.parameters(); builder.sigAlg(parameters.get(Saml2ParameterNames.SIG_ALG)) .signature(parameters.get(Saml2ParameterNames.SIGNATURE)); } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolverTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolverTests.java index d314a0166d..400da8de91 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolverTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolverTests.java @@ -18,10 +18,13 @@ package org.springframework.security.saml2.provider.service.web.authentication; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.mockito.Answers; +import org.mockito.MockedStatic; import org.opensaml.xmlsec.signature.support.SignatureConstants; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.security.saml2.Saml2Exception; +import org.springframework.security.saml2.core.Saml2ParameterNames; import org.springframework.security.saml2.core.Saml2X509Credential; import org.springframework.security.saml2.core.TestSaml2X509Credentials; import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest; @@ -32,6 +35,12 @@ import org.springframework.security.saml2.provider.service.registration.TestRely import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; /** * Tests for {@link OpenSamlAuthenticationRequestResolver} @@ -103,8 +112,8 @@ public class OpenSamlAuthenticationRequestResolverTests { .build(); OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration); assertThatExceptionOfType(Saml2Exception.class) - .isThrownBy(() -> resolver.resolve(request, (r, authnRequest) -> { - })); + .isThrownBy(() -> resolver.resolve(request, (r, authnRequest) -> { + })); } @Test @@ -172,6 +181,58 @@ public class OpenSamlAuthenticationRequestResolverTests { assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT); } + @Test + public void resolveAuthenticationRequestWhenSignedAndRelayStateIsNullThenSignsWithoutRelayState() { + try (MockedStatic openSamlSigningUtilsMockedStatic = mockStatic( + OpenSamlSigningUtils.class, Answers.CALLS_REAL_METHODS)) { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setPathInfo("/saml2/authenticate/registration-id"); + RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder + .assertingPartyDetails((party) -> party.wantAuthnRequestsSigned(true)) + .build(); + OpenSamlSigningUtils.QueryParametersPartial queryParametersPartialSpy = spy( + new OpenSamlSigningUtils.QueryParametersPartial(registration)); + openSamlSigningUtilsMockedStatic.when(() -> OpenSamlSigningUtils.sign(any())) + .thenReturn(queryParametersPartialSpy); + OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration); + resolver.setRelayStateResolver((source) -> null); + Saml2RedirectAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> { + }); + assertThat(result.getSamlRequest()).isNotEmpty(); + assertThat(result.getRelayState()).isNull(); + assertThat(result.getSigAlg()).isNotNull(); + assertThat(result.getSignature()).isNotNull(); + assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT); + verify(queryParametersPartialSpy, never()).param(eq(Saml2ParameterNames.RELAY_STATE), any()); + } + } + + @Test + public void resolveAuthenticationRequestWhenSignedAndRelayStateIsEmptyThenSignsWithEmptyRelayState() { + try (MockedStatic openSamlSigningUtilsMockedStatic = mockStatic( + OpenSamlSigningUtils.class, Answers.CALLS_REAL_METHODS)) { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setPathInfo("/saml2/authenticate/registration-id"); + RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder + .assertingPartyDetails((party) -> party.wantAuthnRequestsSigned(true)) + .build(); + OpenSamlSigningUtils.QueryParametersPartial queryParametersPartialSpy = spy( + new OpenSamlSigningUtils.QueryParametersPartial(registration)); + openSamlSigningUtilsMockedStatic.when(() -> OpenSamlSigningUtils.sign(any())) + .thenReturn(queryParametersPartialSpy); + OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration); + resolver.setRelayStateResolver((source) -> ""); + Saml2RedirectAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> { + }); + assertThat(result.getSamlRequest()).isNotEmpty(); + assertThat(result.getRelayState()).isEmpty(); + assertThat(result.getSigAlg()).isNotNull(); + assertThat(result.getSignature()).isNotNull(); + assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT); + verify(queryParametersPartialSpy).param(eq(Saml2ParameterNames.RELAY_STATE), eq("")); + } + } + private OpenSamlAuthenticationRequestResolver authenticationRequestResolver(RelyingPartyRegistration registration) { return new OpenSamlAuthenticationRequestResolver((request, id) -> registration); }