Add SP NameIDFormat Support

closes gh-9115
This commit is contained in:
Arnaud Mergey 2020-12-01 16:54:13 +01:00 committed by Josh Cummings
parent a68411566e
commit dbe4d704f8
6 changed files with 82 additions and 5 deletions

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -31,6 +31,7 @@ import org.opensaml.saml.common.xml.SAMLConstants;
import org.opensaml.saml.saml2.metadata.AssertionConsumerService; import org.opensaml.saml.saml2.metadata.AssertionConsumerService;
import org.opensaml.saml.saml2.metadata.EntityDescriptor; import org.opensaml.saml.saml2.metadata.EntityDescriptor;
import org.opensaml.saml.saml2.metadata.KeyDescriptor; import org.opensaml.saml.saml2.metadata.KeyDescriptor;
import org.opensaml.saml.saml2.metadata.NameIDFormat;
import org.opensaml.saml.saml2.metadata.SPSSODescriptor; import org.opensaml.saml.saml2.metadata.SPSSODescriptor;
import org.opensaml.saml.saml2.metadata.SingleLogoutService; import org.opensaml.saml.saml2.metadata.SingleLogoutService;
import org.opensaml.saml.saml2.metadata.impl.EntityDescriptorMarshaller; import org.opensaml.saml.saml2.metadata.impl.EntityDescriptorMarshaller;
@ -87,6 +88,9 @@ public final class OpenSamlMetadataResolver implements Saml2MetadataResolver {
.addAll(buildKeys(registration.getDecryptionX509Credentials(), UsageType.ENCRYPTION)); .addAll(buildKeys(registration.getDecryptionX509Credentials(), UsageType.ENCRYPTION));
spSsoDescriptor.getAssertionConsumerServices().add(buildAssertionConsumerService(registration)); spSsoDescriptor.getAssertionConsumerServices().add(buildAssertionConsumerService(registration));
spSsoDescriptor.getSingleLogoutServices().add(buildSingleLogoutService(registration)); spSsoDescriptor.getSingleLogoutServices().add(buildSingleLogoutService(registration));
if (registration.getNameIdFormat() != null) {
spSsoDescriptor.getNameIDFormats().add(buildNameIDFormat(registration));
}
return spSsoDescriptor; return spSsoDescriptor;
} }
@ -133,6 +137,12 @@ public final class OpenSamlMetadataResolver implements Saml2MetadataResolver {
return singleLogoutService; return singleLogoutService;
} }
private NameIDFormat buildNameIDFormat(RelyingPartyRegistration registration) {
NameIDFormat nameIdFormat = build(NameIDFormat.DEFAULT_ELEMENT_NAME);
nameIdFormat.setFormat(registration.getNameIdFormat());
return nameIdFormat;
}
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private <T> T build(QName elementName) { private <T> T build(QName elementName) {
XMLObjectBuilder<?> builder = XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(elementName); XMLObjectBuilder<?> builder = XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(elementName);

View File

@ -87,6 +87,8 @@ public final class RelyingPartyRegistration {
private final Saml2MessageBinding singleLogoutServiceBinding; private final Saml2MessageBinding singleLogoutServiceBinding;
private final String nameIdFormat;
private final ProviderDetails providerDetails; private final ProviderDetails providerDetails;
private final List<org.springframework.security.saml2.credentials.Saml2X509Credential> credentials; private final List<org.springframework.security.saml2.credentials.Saml2X509Credential> credentials;
@ -98,7 +100,7 @@ public final class RelyingPartyRegistration {
private RelyingPartyRegistration(String registrationId, String entityId, String assertionConsumerServiceLocation, private RelyingPartyRegistration(String registrationId, String entityId, String assertionConsumerServiceLocation,
Saml2MessageBinding assertionConsumerServiceBinding, String singleLogoutServiceLocation, Saml2MessageBinding assertionConsumerServiceBinding, String singleLogoutServiceLocation,
String singleLogoutServiceResponseLocation, Saml2MessageBinding singleLogoutServiceBinding, String singleLogoutServiceResponseLocation, Saml2MessageBinding singleLogoutServiceBinding,
ProviderDetails providerDetails, ProviderDetails providerDetails, String nameIdFormat,
Collection<org.springframework.security.saml2.credentials.Saml2X509Credential> credentials, Collection<org.springframework.security.saml2.credentials.Saml2X509Credential> credentials,
Collection<Saml2X509Credential> decryptionX509Credentials, Collection<Saml2X509Credential> decryptionX509Credentials,
Collection<Saml2X509Credential> signingX509Credentials) { Collection<Saml2X509Credential> signingX509Credentials) {
@ -129,6 +131,7 @@ public final class RelyingPartyRegistration {
this.singleLogoutServiceLocation = singleLogoutServiceLocation; this.singleLogoutServiceLocation = singleLogoutServiceLocation;
this.singleLogoutServiceResponseLocation = singleLogoutServiceResponseLocation; this.singleLogoutServiceResponseLocation = singleLogoutServiceResponseLocation;
this.singleLogoutServiceBinding = singleLogoutServiceBinding; this.singleLogoutServiceBinding = singleLogoutServiceBinding;
this.nameIdFormat = nameIdFormat;
this.providerDetails = providerDetails; this.providerDetails = providerDetails;
this.credentials = Collections.unmodifiableList(new LinkedList<>(credentials)); this.credentials = Collections.unmodifiableList(new LinkedList<>(credentials));
this.decryptionX509Credentials = Collections.unmodifiableList(new LinkedList<>(decryptionX509Credentials)); this.decryptionX509Credentials = Collections.unmodifiableList(new LinkedList<>(decryptionX509Credentials));
@ -234,6 +237,15 @@ public final class RelyingPartyRegistration {
return this.singleLogoutServiceResponseLocation; return this.singleLogoutServiceResponseLocation;
} }
/**
* Get the NameID format.
* @return the NameID format
* @since 5.7
*/
public String getNameIdFormat() {
return this.nameIdFormat;
}
/** /**
* Get the {@link Collection} of decryption {@link Saml2X509Credential}s associated * Get the {@link Collection} of decryption {@link Saml2X509Credential}s associated
* with this relying party * with this relying party
@ -424,6 +436,7 @@ public final class RelyingPartyRegistration {
.singleLogoutServiceLocation(registration.getSingleLogoutServiceLocation()) .singleLogoutServiceLocation(registration.getSingleLogoutServiceLocation())
.singleLogoutServiceResponseLocation(registration.getSingleLogoutServiceResponseLocation()) .singleLogoutServiceResponseLocation(registration.getSingleLogoutServiceResponseLocation())
.singleLogoutServiceBinding(registration.getSingleLogoutServiceBinding()) .singleLogoutServiceBinding(registration.getSingleLogoutServiceBinding())
.nameIdFormat(registration.getNameIdFormat())
.assertingPartyDetails((assertingParty) -> assertingParty .assertingPartyDetails((assertingParty) -> assertingParty
.entityId(registration.getAssertingPartyDetails().getEntityId()) .entityId(registration.getAssertingPartyDetails().getEntityId())
.wantAuthnRequestsSigned(registration.getAssertingPartyDetails().getWantAuthnRequestsSigned()) .wantAuthnRequestsSigned(registration.getAssertingPartyDetails().getWantAuthnRequestsSigned())
@ -1018,6 +1031,8 @@ public final class RelyingPartyRegistration {
private Saml2MessageBinding singleLogoutServiceBinding = Saml2MessageBinding.POST; private Saml2MessageBinding singleLogoutServiceBinding = Saml2MessageBinding.POST;
private String nameIdFormat = null;
private ProviderDetails.Builder providerDetails = new ProviderDetails.Builder(); private ProviderDetails.Builder providerDetails = new ProviderDetails.Builder();
private Collection<org.springframework.security.saml2.credentials.Saml2X509Credential> credentials = new HashSet<>(); private Collection<org.springframework.security.saml2.credentials.Saml2X509Credential> credentials = new HashSet<>();
@ -1173,6 +1188,17 @@ public final class RelyingPartyRegistration {
return this; return this;
} }
/**
* Set the NameID format
* @param nameIdFormat
* @return the {@link Builder} for further configuration
* @since 5.7
*/
public Builder nameIdFormat(String nameIdFormat) {
this.nameIdFormat = nameIdFormat;
return this;
}
/** /**
* Apply this {@link Consumer} to further configure the Asserting Party details * Apply this {@link Consumer} to further configure the Asserting Party details
* @param assertingPartyDetails The {@link Consumer} to apply * @param assertingPartyDetails The {@link Consumer} to apply
@ -1321,7 +1347,7 @@ public final class RelyingPartyRegistration {
return new RelyingPartyRegistration(this.registrationId, this.entityId, return new RelyingPartyRegistration(this.registrationId, this.entityId,
this.assertionConsumerServiceLocation, this.assertionConsumerServiceBinding, this.assertionConsumerServiceLocation, this.assertionConsumerServiceBinding,
this.singleLogoutServiceLocation, this.singleLogoutServiceResponseLocation, this.singleLogoutServiceLocation, this.singleLogoutServiceResponseLocation,
this.singleLogoutServiceBinding, this.providerDetails.build(), this.credentials, this.singleLogoutServiceBinding, this.providerDetails.build(), this.nameIdFormat, this.credentials,
this.decryptionX509Credentials, this.signingX509Credentials); this.decryptionX509Credentials, this.signingX509Credentials);
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -27,8 +27,10 @@ import org.opensaml.core.xml.config.XMLObjectProviderRegistry;
import org.opensaml.saml.common.xml.SAMLConstants; import org.opensaml.saml.common.xml.SAMLConstants;
import org.opensaml.saml.saml2.core.AuthnRequest; import org.opensaml.saml.saml2.core.AuthnRequest;
import org.opensaml.saml.saml2.core.Issuer; import org.opensaml.saml.saml2.core.Issuer;
import org.opensaml.saml.saml2.core.NameIDPolicy;
import org.opensaml.saml.saml2.core.impl.AuthnRequestBuilder; import org.opensaml.saml.saml2.core.impl.AuthnRequestBuilder;
import org.opensaml.saml.saml2.core.impl.IssuerBuilder; import org.opensaml.saml.saml2.core.impl.IssuerBuilder;
import org.opensaml.saml.saml2.core.impl.NameIDPolicyBuilder;
import org.springframework.core.convert.converter.Converter; import org.springframework.core.convert.converter.Converter;
import org.springframework.security.saml2.core.OpenSamlInitializationService; import org.springframework.security.saml2.core.OpenSamlInitializationService;
@ -56,6 +58,8 @@ public final class OpenSaml4AuthenticationRequestFactory implements Saml2Authent
private final IssuerBuilder issuerBuilder; private final IssuerBuilder issuerBuilder;
private final NameIDPolicyBuilder nameIdPolicyBuilder;
private Clock clock = Clock.systemUTC(); private Clock clock = Clock.systemUTC();
private Converter<Saml2AuthenticationRequestContext, AuthnRequest> authenticationRequestContextConverter; private Converter<Saml2AuthenticationRequestContext, AuthnRequest> authenticationRequestContextConverter;
@ -69,6 +73,8 @@ public final class OpenSaml4AuthenticationRequestFactory implements Saml2Authent
this.authnRequestBuilder = (AuthnRequestBuilder) registry.getBuilderFactory() this.authnRequestBuilder = (AuthnRequestBuilder) registry.getBuilderFactory()
.getBuilder(AuthnRequest.DEFAULT_ELEMENT_NAME); .getBuilder(AuthnRequest.DEFAULT_ELEMENT_NAME);
this.issuerBuilder = (IssuerBuilder) registry.getBuilderFactory().getBuilder(Issuer.DEFAULT_ELEMENT_NAME); this.issuerBuilder = (IssuerBuilder) registry.getBuilderFactory().getBuilder(Issuer.DEFAULT_ELEMENT_NAME);
this.nameIdPolicyBuilder = (NameIDPolicyBuilder) registry.getBuilderFactory()
.getBuilder(NameIDPolicy.DEFAULT_ELEMENT_NAME);
} }
/** /**
@ -152,6 +158,9 @@ public final class OpenSaml4AuthenticationRequestFactory implements Saml2Authent
auth.setProtocolBinding(SAMLConstants.SAML2_POST_BINDING_URI); auth.setProtocolBinding(SAMLConstants.SAML2_POST_BINDING_URI);
} }
auth.setProtocolBinding(protocolBinding); auth.setProtocolBinding(protocolBinding);
if (auth.getNameIDPolicy() == null) {
setNameIdPolicy(auth, context.getRelyingPartyRegistration());
}
Issuer iss = this.issuerBuilder.buildObject(); Issuer iss = this.issuerBuilder.buildObject();
iss.setValue(issuer); iss.setValue(issuer);
auth.setIssuer(iss); auth.setIssuer(iss);
@ -160,6 +169,15 @@ public final class OpenSaml4AuthenticationRequestFactory implements Saml2Authent
return auth; return auth;
} }
private void setNameIdPolicy(AuthnRequest authnRequest, RelyingPartyRegistration registration) {
if (!StringUtils.hasText(registration.getNameIdFormat())) {
return;
}
NameIDPolicy nameIdPolicy = this.nameIdPolicyBuilder.buildObject();
nameIdPolicy.setFormat(registration.getNameIdFormat());
authnRequest.setNameIDPolicy(nameIdPolicy);
}
/** /**
* Set the strategy for building an {@link AuthnRequest} from a given context * Set the strategy for building an {@link AuthnRequest} from a given context
* @param authenticationRequestContextConverter the conversion strategy to use * @param authenticationRequestContextConverter the conversion strategy to use

View File

@ -242,6 +242,18 @@ public class OpenSaml4AuthenticationRequestFactoryTests {
assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT); assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT);
} }
@Test
public void createAuthenticationRequestWhenSetNameIDPolicyThenReturnsCorrectNameIDPolicy() {
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().nameIdFormat("format").build();
this.context = this.contextBuilder.relayState("Relay State Value").relyingPartyRegistration(registration)
.build();
AuthnRequest authn = getAuthNRequest(Saml2MessageBinding.POST);
assertThat(authn.getNameIDPolicy()).isNotNull();
assertThat(authn.getNameIDPolicy().getAllowCreate()).isFalse();
assertThat(authn.getNameIDPolicy().getFormat()).isEqualTo("format");
assertThat(authn.getNameIDPolicy().getSPNameQualifier()).isNull();
}
private AuthnRequest authnRequest() { private AuthnRequest authnRequest() {
AuthnRequest authnRequest = TestOpenSamlObjects.authnRequest(); AuthnRequest authnRequest = TestOpenSamlObjects.authnRequest();
authnRequest.setIssueInstant(Instant.now()); authnRequest.setIssueInstant(Instant.now());

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -61,4 +61,13 @@ public class OpenSamlMetadataResolverTests {
.contains("ResponseLocation=\"https://rp.example.org/logout/saml2/response\""); .contains("ResponseLocation=\"https://rp.example.org/logout/saml2/response\"");
} }
@Test
public void resolveWhenRelyingPartyNameIDFormatThenMetadataMatches() {
RelyingPartyRegistration relyingPartyRegistration = TestRelyingPartyRegistrations.full().nameIdFormat("format")
.build();
OpenSamlMetadataResolver openSamlMetadataResolver = new OpenSamlMetadataResolver();
String metadata = openSamlMetadataResolver.resolve(relyingPartyRegistration);
assertThat(metadata).contains("<md:NameIDFormat>format</md:NameIDFormat>");
}
} }

View File

@ -28,6 +28,7 @@ public class RelyingPartyRegistrationTests {
@Test @Test
public void withRelyingPartyRegistrationWorks() { public void withRelyingPartyRegistrationWorks() {
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration() RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration()
.nameIdFormat("format")
.assertingPartyDetails((a) -> a.singleSignOnServiceBinding(Saml2MessageBinding.POST)) .assertingPartyDetails((a) -> a.singleSignOnServiceBinding(Saml2MessageBinding.POST))
.assertingPartyDetails((a) -> a.wantAuthnRequestsSigned(false)) .assertingPartyDetails((a) -> a.wantAuthnRequestsSigned(false))
.assertingPartyDetails((a) -> a.signingAlgorithms((algs) -> algs.add("alg"))) .assertingPartyDetails((a) -> a.signingAlgorithms((algs) -> algs.add("alg")))
@ -74,6 +75,7 @@ public class RelyingPartyRegistrationTests {
.isEqualTo(registration.getAssertingPartyDetails().getVerificationX509Credentials()); .isEqualTo(registration.getAssertingPartyDetails().getVerificationX509Credentials());
assertThat(copy.getAssertingPartyDetails().getSigningAlgorithms()) assertThat(copy.getAssertingPartyDetails().getSigningAlgorithms())
.isEqualTo(registration.getAssertingPartyDetails().getSigningAlgorithms()); .isEqualTo(registration.getAssertingPartyDetails().getSigningAlgorithms());
assertThat(copy.getNameIdFormat()).isEqualTo(registration.getNameIdFormat());
} }
@Test @Test