diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolver.java index 1f0d5c19af..7b850a5483 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolver.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolver.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,6 +31,7 @@ import org.opensaml.saml.common.xml.SAMLConstants; import org.opensaml.saml.saml2.metadata.AssertionConsumerService; import org.opensaml.saml.saml2.metadata.EntityDescriptor; import org.opensaml.saml.saml2.metadata.KeyDescriptor; +import org.opensaml.saml.saml2.metadata.NameIDFormat; import org.opensaml.saml.saml2.metadata.SPSSODescriptor; import org.opensaml.saml.saml2.metadata.SingleLogoutService; import org.opensaml.saml.saml2.metadata.impl.EntityDescriptorMarshaller; @@ -87,6 +88,9 @@ public final class OpenSamlMetadataResolver implements Saml2MetadataResolver { .addAll(buildKeys(registration.getDecryptionX509Credentials(), UsageType.ENCRYPTION)); spSsoDescriptor.getAssertionConsumerServices().add(buildAssertionConsumerService(registration)); spSsoDescriptor.getSingleLogoutServices().add(buildSingleLogoutService(registration)); + if (registration.getNameIdFormat() != null) { + spSsoDescriptor.getNameIDFormats().add(buildNameIDFormat(registration)); + } return spSsoDescriptor; } @@ -133,6 +137,12 @@ public final class OpenSamlMetadataResolver implements Saml2MetadataResolver { return singleLogoutService; } + private NameIDFormat buildNameIDFormat(RelyingPartyRegistration registration) { + NameIDFormat nameIdFormat = build(NameIDFormat.DEFAULT_ELEMENT_NAME); + nameIdFormat.setFormat(registration.getNameIdFormat()); + return nameIdFormat; + } + @SuppressWarnings("unchecked") private T build(QName elementName) { XMLObjectBuilder builder = XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(elementName); diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.java index d07a3664f8..43e61b11e1 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.java @@ -87,6 +87,8 @@ public final class RelyingPartyRegistration { private final Saml2MessageBinding singleLogoutServiceBinding; + private final String nameIdFormat; + private final ProviderDetails providerDetails; private final List credentials; @@ -98,7 +100,7 @@ public final class RelyingPartyRegistration { private RelyingPartyRegistration(String registrationId, String entityId, String assertionConsumerServiceLocation, Saml2MessageBinding assertionConsumerServiceBinding, String singleLogoutServiceLocation, String singleLogoutServiceResponseLocation, Saml2MessageBinding singleLogoutServiceBinding, - ProviderDetails providerDetails, + ProviderDetails providerDetails, String nameIdFormat, Collection credentials, Collection decryptionX509Credentials, Collection signingX509Credentials) { @@ -129,6 +131,7 @@ public final class RelyingPartyRegistration { this.singleLogoutServiceLocation = singleLogoutServiceLocation; this.singleLogoutServiceResponseLocation = singleLogoutServiceResponseLocation; this.singleLogoutServiceBinding = singleLogoutServiceBinding; + this.nameIdFormat = nameIdFormat; this.providerDetails = providerDetails; this.credentials = Collections.unmodifiableList(new LinkedList<>(credentials)); this.decryptionX509Credentials = Collections.unmodifiableList(new LinkedList<>(decryptionX509Credentials)); @@ -234,6 +237,15 @@ public final class RelyingPartyRegistration { return this.singleLogoutServiceResponseLocation; } + /** + * Get the NameID format. + * @return the NameID format + * @since 5.7 + */ + public String getNameIdFormat() { + return this.nameIdFormat; + } + /** * Get the {@link Collection} of decryption {@link Saml2X509Credential}s associated * with this relying party @@ -424,6 +436,7 @@ public final class RelyingPartyRegistration { .singleLogoutServiceLocation(registration.getSingleLogoutServiceLocation()) .singleLogoutServiceResponseLocation(registration.getSingleLogoutServiceResponseLocation()) .singleLogoutServiceBinding(registration.getSingleLogoutServiceBinding()) + .nameIdFormat(registration.getNameIdFormat()) .assertingPartyDetails((assertingParty) -> assertingParty .entityId(registration.getAssertingPartyDetails().getEntityId()) .wantAuthnRequestsSigned(registration.getAssertingPartyDetails().getWantAuthnRequestsSigned()) @@ -1018,6 +1031,8 @@ public final class RelyingPartyRegistration { private Saml2MessageBinding singleLogoutServiceBinding = Saml2MessageBinding.POST; + private String nameIdFormat = null; + private ProviderDetails.Builder providerDetails = new ProviderDetails.Builder(); private Collection credentials = new HashSet<>(); @@ -1173,6 +1188,17 @@ public final class RelyingPartyRegistration { return this; } + /** + * Set the NameID format + * @param nameIdFormat + * @return the {@link Builder} for further configuration + * @since 5.7 + */ + public Builder nameIdFormat(String nameIdFormat) { + this.nameIdFormat = nameIdFormat; + return this; + } + /** * Apply this {@link Consumer} to further configure the Asserting Party details * @param assertingPartyDetails The {@link Consumer} to apply @@ -1321,7 +1347,7 @@ public final class RelyingPartyRegistration { return new RelyingPartyRegistration(this.registrationId, this.entityId, this.assertionConsumerServiceLocation, this.assertionConsumerServiceBinding, this.singleLogoutServiceLocation, this.singleLogoutServiceResponseLocation, - this.singleLogoutServiceBinding, this.providerDetails.build(), this.credentials, + this.singleLogoutServiceBinding, this.providerDetails.build(), this.nameIdFormat, this.credentials, this.decryptionX509Credentials, this.signingX509Credentials); } diff --git a/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationRequestFactory.java b/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationRequestFactory.java index dcfa1cfdbc..ec02ca2a06 100644 --- a/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationRequestFactory.java +++ b/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationRequestFactory.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,8 +27,10 @@ import org.opensaml.core.xml.config.XMLObjectProviderRegistry; import org.opensaml.saml.common.xml.SAMLConstants; import org.opensaml.saml.saml2.core.AuthnRequest; import org.opensaml.saml.saml2.core.Issuer; +import org.opensaml.saml.saml2.core.NameIDPolicy; import org.opensaml.saml.saml2.core.impl.AuthnRequestBuilder; import org.opensaml.saml.saml2.core.impl.IssuerBuilder; +import org.opensaml.saml.saml2.core.impl.NameIDPolicyBuilder; import org.springframework.core.convert.converter.Converter; import org.springframework.security.saml2.core.OpenSamlInitializationService; @@ -56,6 +58,8 @@ public final class OpenSaml4AuthenticationRequestFactory implements Saml2Authent private final IssuerBuilder issuerBuilder; + private final NameIDPolicyBuilder nameIdPolicyBuilder; + private Clock clock = Clock.systemUTC(); private Converter authenticationRequestContextConverter; @@ -69,6 +73,8 @@ public final class OpenSaml4AuthenticationRequestFactory implements Saml2Authent this.authnRequestBuilder = (AuthnRequestBuilder) registry.getBuilderFactory() .getBuilder(AuthnRequest.DEFAULT_ELEMENT_NAME); this.issuerBuilder = (IssuerBuilder) registry.getBuilderFactory().getBuilder(Issuer.DEFAULT_ELEMENT_NAME); + this.nameIdPolicyBuilder = (NameIDPolicyBuilder) registry.getBuilderFactory() + .getBuilder(NameIDPolicy.DEFAULT_ELEMENT_NAME); } /** @@ -152,6 +158,9 @@ public final class OpenSaml4AuthenticationRequestFactory implements Saml2Authent auth.setProtocolBinding(SAMLConstants.SAML2_POST_BINDING_URI); } auth.setProtocolBinding(protocolBinding); + if (auth.getNameIDPolicy() == null) { + setNameIdPolicy(auth, context.getRelyingPartyRegistration()); + } Issuer iss = this.issuerBuilder.buildObject(); iss.setValue(issuer); auth.setIssuer(iss); @@ -160,6 +169,15 @@ public final class OpenSaml4AuthenticationRequestFactory implements Saml2Authent return auth; } + private void setNameIdPolicy(AuthnRequest authnRequest, RelyingPartyRegistration registration) { + if (!StringUtils.hasText(registration.getNameIdFormat())) { + return; + } + NameIDPolicy nameIdPolicy = this.nameIdPolicyBuilder.buildObject(); + nameIdPolicy.setFormat(registration.getNameIdFormat()); + authnRequest.setNameIDPolicy(nameIdPolicy); + } + /** * Set the strategy for building an {@link AuthnRequest} from a given context * @param authenticationRequestContextConverter the conversion strategy to use diff --git a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationRequestFactoryTests.java b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationRequestFactoryTests.java index 84c415ebe5..0aced67097 100644 --- a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationRequestFactoryTests.java +++ b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationRequestFactoryTests.java @@ -242,6 +242,18 @@ public class OpenSaml4AuthenticationRequestFactoryTests { assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT); } + @Test + public void createAuthenticationRequestWhenSetNameIDPolicyThenReturnsCorrectNameIDPolicy() { + RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().nameIdFormat("format").build(); + this.context = this.contextBuilder.relayState("Relay State Value").relyingPartyRegistration(registration) + .build(); + AuthnRequest authn = getAuthNRequest(Saml2MessageBinding.POST); + assertThat(authn.getNameIDPolicy()).isNotNull(); + assertThat(authn.getNameIDPolicy().getAllowCreate()).isFalse(); + assertThat(authn.getNameIDPolicy().getFormat()).isEqualTo("format"); + assertThat(authn.getNameIDPolicy().getSPNameQualifier()).isNull(); + } + private AuthnRequest authnRequest() { AuthnRequest authnRequest = TestOpenSamlObjects.authnRequest(); authnRequest.setIssueInstant(Instant.now()); diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolverTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolverTests.java index d42fc875be..be2069ab94 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolverTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolverTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -61,4 +61,13 @@ public class OpenSamlMetadataResolverTests { .contains("ResponseLocation=\"https://rp.example.org/logout/saml2/response\""); } + @Test + public void resolveWhenRelyingPartyNameIDFormatThenMetadataMatches() { + RelyingPartyRegistration relyingPartyRegistration = TestRelyingPartyRegistrations.full().nameIdFormat("format") + .build(); + OpenSamlMetadataResolver openSamlMetadataResolver = new OpenSamlMetadataResolver(); + String metadata = openSamlMetadataResolver.resolve(relyingPartyRegistration); + assertThat(metadata).contains("format"); + } + } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationTests.java index d25d4b981c..63e9d58505 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationTests.java @@ -28,6 +28,7 @@ public class RelyingPartyRegistrationTests { @Test public void withRelyingPartyRegistrationWorks() { RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration() + .nameIdFormat("format") .assertingPartyDetails((a) -> a.singleSignOnServiceBinding(Saml2MessageBinding.POST)) .assertingPartyDetails((a) -> a.wantAuthnRequestsSigned(false)) .assertingPartyDetails((a) -> a.signingAlgorithms((algs) -> algs.add("alg"))) @@ -74,6 +75,7 @@ public class RelyingPartyRegistrationTests { .isEqualTo(registration.getAssertingPartyDetails().getVerificationX509Credentials()); assertThat(copy.getAssertingPartyDetails().getSigningAlgorithms()) .isEqualTo(registration.getAssertingPartyDetails().getSigningAlgorithms()); + assertThat(copy.getNameIdFormat()).isEqualTo(registration.getNameIdFormat()); } @Test