diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java index 130172e3fb..f554258b3a 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java @@ -29,6 +29,7 @@ import org.opensaml.saml.common.xml.SAMLConstants; import org.opensaml.saml.saml2.core.AuthnRequest; import org.opensaml.saml.saml2.core.Issuer; +import org.springframework.core.convert.converter.Converter; import org.springframework.security.saml2.credentials.Saml2X509Credential; import org.springframework.security.saml2.provider.service.authentication.Saml2RedirectAuthenticationRequest.Builder; import org.springframework.util.Assert; @@ -43,7 +44,14 @@ import static org.springframework.security.saml2.provider.service.authentication public class OpenSamlAuthenticationRequestFactory implements Saml2AuthenticationRequestFactory { private Clock clock = Clock.systemUTC(); private final OpenSamlImplementation saml = OpenSamlImplementation.getInstance(); - private String protocolBinding = SAMLConstants.SAML2_POST_BINDING_URI; + + private Converter protocolBindingResolver = + context -> { + if (context == null) { + return SAMLConstants.SAML2_POST_BINDING_URI; + } + return context.getRelyingPartyRegistration().getAssertionConsumerServiceBinding().getUrn(); + }; private Function> authnRequestConsumerResolver = context -> authnRequest -> {}; @@ -52,7 +60,8 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication @Deprecated public String createAuthenticationRequest(Saml2AuthenticationRequest request) { AuthnRequest authnRequest = createAuthnRequest(request.getIssuer(), - request.getDestination(), request.getAssertionConsumerServiceUrl()); + request.getDestination(), request.getAssertionConsumerServiceUrl(), + this.protocolBindingResolver.convert(null)); return this.saml.serialize(authnRequest, request.getCredentials()); } @@ -101,12 +110,14 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication private AuthnRequest createAuthnRequest(Saml2AuthenticationRequestContext context) { AuthnRequest authnRequest = createAuthnRequest(context.getIssuer(), - context.getDestination(), context.getAssertionConsumerServiceUrl()); + context.getDestination(), context.getAssertionConsumerServiceUrl(), + this.protocolBindingResolver.convert(context)); this.authnRequestConsumerResolver.apply(context).accept(authnRequest); return authnRequest; } - private AuthnRequest createAuthnRequest(String issuer, String destination, String assertionConsumerServiceUrl) { + private AuthnRequest createAuthnRequest + (String issuer, String destination, String assertionConsumerServiceUrl, String protocolBinding) { AuthnRequest auth = this.saml.buildSamlObject(AuthnRequest.DEFAULT_ELEMENT_NAME); auth.setID("ARQ" + UUID.randomUUID().toString().substring(1)); auth.setIssueInstant(new DateTime(this.clock.millis())); @@ -155,13 +166,16 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication * @param protocolBinding either {@link SAMLConstants#SAML2_POST_BINDING_URI} or * {@link SAMLConstants#SAML2_REDIRECT_BINDING_URI} * @throws IllegalArgumentException if the protocolBinding is not valid + * @deprecated Use {@link org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.Builder#assertionConsumerServiceBinding} + * instead */ + @Deprecated public void setProtocolBinding(String protocolBinding) { boolean isAllowedBinding = SAMLConstants.SAML2_POST_BINDING_URI.equals(protocolBinding) || SAMLConstants.SAML2_REDIRECT_BINDING_URI.equals(protocolBinding); if (!isAllowedBinding) { throw new IllegalArgumentException("Invalid protocol binding: " + protocolBinding); } - this.protocolBinding = protocolBinding; + this.protocolBindingResolver = context -> protocolBinding; } } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.java index 14c29c23d1..0afe03f026 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.java @@ -68,6 +68,7 @@ public class RelyingPartyRegistration { private final String registrationId; private final String entityId; private final String assertionConsumerServiceLocation; + private final Saml2MessageBinding assertionConsumerServiceBinding; private final ProviderDetails providerDetails; private final List credentials; @@ -75,12 +76,14 @@ public class RelyingPartyRegistration { String registrationId, String entityId, String assertionConsumerServiceLocation, + Saml2MessageBinding assertionConsumerServiceBinding, ProviderDetails providerDetails, List credentials) { Assert.hasText(registrationId, "registrationId cannot be empty"); Assert.hasText(entityId, "entityId cannot be empty"); Assert.hasText(assertionConsumerServiceLocation, "assertionConsumerServiceLocation cannot be empty"); + Assert.notNull(assertionConsumerServiceBinding, "assertionConsumerServiceBinding cannot be null"); Assert.notNull(providerDetails, "providerDetails cannot be null"); Assert.notEmpty(credentials, "credentials cannot be empty"); for (Saml2X509Credential c : credentials) { @@ -89,6 +92,7 @@ public class RelyingPartyRegistration { this.registrationId = registrationId; this.entityId = entityId; this.assertionConsumerServiceLocation = assertionConsumerServiceLocation; + this.assertionConsumerServiceBinding = assertionConsumerServiceBinding; this.providerDetails = providerDetails; this.credentials = Collections.unmodifiableList(new LinkedList<>(credentials)); } @@ -138,6 +142,18 @@ public class RelyingPartyRegistration { return this.assertionConsumerServiceLocation; } + /** + * Get the AssertionConsumerService Binding. + * Equivalent to the value found in <AssertionConsumerService Binding="..."/> + * in the relying party's <SPSSODescriptor>. + * + * @return the AssertionConsumerService Binding + * @since 5.4 + */ + public Saml2MessageBinding getAssertionConsumerServiceBinding() { + return this.assertionConsumerServiceBinding; + } + /** * Get the configuration details for the Asserting Party * @@ -280,6 +296,7 @@ public class RelyingPartyRegistration { return withRegistrationId(registration.getRegistrationId()) .entityId(registration.getEntityId()) .assertionConsumerServiceLocation(registration.getAssertionConsumerServiceLocation()) + .assertionConsumerServiceBinding(registration.getAssertionConsumerServiceBinding()) .assertingPartyDetails(c -> c .entityId(registration.getAssertingPartyDetails().getEntityId()) .wantAuthnRequestsSigned(registration.getAssertingPartyDetails().getWantAuthnRequestsSigned()) @@ -575,6 +592,7 @@ public class RelyingPartyRegistration { private String registrationId; private String entityId = "{baseUrl}/saml2/service-provider-metadata/{registrationId}"; private String assertionConsumerServiceLocation; + private Saml2MessageBinding assertionConsumerServiceBinding = Saml2MessageBinding.POST; private ProviderDetails.Builder providerDetails = new ProviderDetails.Builder(); private List credentials = new LinkedList<>(); @@ -633,6 +651,23 @@ public class RelyingPartyRegistration { return this; } + /** + * Set the AssertionConsumerService + * Binding. + * + *

+ * Equivalent to the value found in <AssertionConsumerService Binding="..."/> + * in the relying party's <SPSSODescriptor> + * + * @param assertionConsumerServiceBinding + * @return the {@link Builder} for further configuration + * @since 5.4 + */ + public Builder assertionConsumerServiceBinding(Saml2MessageBinding assertionConsumerServiceBinding) { + this.assertionConsumerServiceBinding = assertionConsumerServiceBinding; + return this; + } + /** * Apply this {@link Consumer} to further configure the Asserting Party details * @@ -738,6 +773,7 @@ public class RelyingPartyRegistration { this.registrationId, this.entityId, this.assertionConsumerServiceLocation, + this.assertionConsumerServiceBinding, this.providerDetails.build(), this.credentials ); diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java index cd504ee9bb..a273f8bf9f 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java @@ -39,6 +39,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartySigningCredential; import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDecode; +import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlInflate; import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration; import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST; import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.REDIRECT; @@ -52,19 +53,21 @@ public class OpenSamlAuthenticationRequestFactoryTests { private Saml2AuthenticationRequestContext.Builder contextBuilder; private Saml2AuthenticationRequestContext context; + private RelyingPartyRegistration.Builder relyingPartyRegistrationBuilder; + private RelyingPartyRegistration relyingPartyRegistration; + @Rule public ExpectedException exception = ExpectedException.none(); - private RelyingPartyRegistration relyingPartyRegistration; @Before public void setUp() { - relyingPartyRegistration = RelyingPartyRegistration.withRegistrationId("id") + this.relyingPartyRegistrationBuilder = RelyingPartyRegistration.withRegistrationId("id") .assertionConsumerServiceLocation("template") .providerDetails(c -> c.webSsoUrl("https://destination/sso")) .providerDetails(c -> c.entityId("remote-entity-id")) .localEntityIdTemplate("local-entity-id") - .credentials(c -> c.add(relyingPartySigningCredential())) - .build(); + .credentials(c -> c.add(relyingPartySigningCredential())); + this.relyingPartyRegistration = this.relyingPartyRegistrationBuilder.build(); contextBuilder = Saml2AuthenticationRequestContext.builder() .issuer("https://issuer") .relyingPartyRegistration(relyingPartyRegistration) @@ -195,6 +198,20 @@ public class OpenSamlAuthenticationRequestFactoryTests { .isInstanceOf(IllegalArgumentException.class); } + @Test + public void createPostAuthenticationRequestWhenAssertionConsumerServiceBindingThenUses() { + RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationBuilder + .assertionConsumerServiceBinding(REDIRECT) + .build(); + Saml2AuthenticationRequestContext context = this.contextBuilder + .relyingPartyRegistration(relyingPartyRegistration) + .build(); + Saml2PostAuthenticationRequest request = this.factory.createPostAuthenticationRequest(context); + String samlRequest = request.getSamlRequest(); + String inflated = new String(samlDecode(samlRequest)); + assertThat(inflated).contains("ProtocolBinding=\"" + SAMLConstants.SAML2_REDIRECT_BINDING_URI + "\""); + } + private AuthnRequest getAuthNRequest(Saml2MessageBinding binding) { AbstractSaml2AuthenticationRequest result = (binding == REDIRECT) ? factory.createRedirectAuthenticationRequest(context) : @@ -202,7 +219,7 @@ public class OpenSamlAuthenticationRequestFactoryTests { String samlRequest = result.getSamlRequest(); assertThat(samlRequest).isNotEmpty(); if (result.getBinding() == REDIRECT) { - samlRequest = Saml2Utils.samlInflate(samlDecode(samlRequest)); + samlRequest = samlInflate(samlDecode(samlRequest)); } else { samlRequest = new String(samlDecode(samlRequest), UTF_8); diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationTests.java index cdf99d7d71..72db62ce6e 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationTests.java @@ -21,6 +21,8 @@ import org.junit.Test; import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter; import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyVerifyingCredential; +import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRegistrationId; import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST; import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration; @@ -31,6 +33,7 @@ public class RelyingPartyRegistrationTests { RelyingPartyRegistration registration = relyingPartyRegistration() .providerDetails(p -> p.binding(POST)) .providerDetails(p -> p.signAuthNRequest(false)) + .assertionConsumerServiceBinding(Saml2MessageBinding.REDIRECT) .build(); RelyingPartyRegistration copy = RelyingPartyRegistration.withRelyingPartyRegistration(registration).build(); compareRegistrations(registration, copy); @@ -76,5 +79,22 @@ public class RelyingPartyRegistrationTests { .isEqualTo(copy.getAssertingPartyDetails().getWantAuthnRequestsSigned()) .isEqualTo(registration.getAssertingPartyDetails().getWantAuthnRequestsSigned()) .isFalse(); + assertThat(copy.getAssertionConsumerServiceBinding()) + .isEqualTo(registration.getAssertionConsumerServiceBinding()); + } + + @Test + public void buildWhenUsingDefaultsThenAssertionConsumerServiceBindingDefaultsToPost() { + RelyingPartyRegistration relyingPartyRegistration = withRegistrationId("id") + .entityId("entity-id") + .assertionConsumerServiceLocation("location") + .assertingPartyDetails(assertingParty -> assertingParty + .entityId("entity-id") + .singleSignOnServiceLocation("location")) + .credentials(c -> c.add(relyingPartyVerifyingCredential())) + .build(); + + assertThat(relyingPartyRegistration.getAssertionConsumerServiceBinding()) + .isEqualTo(POST); } }