Add SP NameIDFormat Support

closes gh-9115
This commit is contained in:
Arnaud Mergey 2020-12-01 16:54:13 +01:00 committed by Josh Cummings
parent a68411566e
commit dbe4d704f8
6 changed files with 82 additions and 5 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -31,6 +31,7 @@ import org.opensaml.saml.common.xml.SAMLConstants;
import org.opensaml.saml.saml2.metadata.AssertionConsumerService;
import org.opensaml.saml.saml2.metadata.EntityDescriptor;
import org.opensaml.saml.saml2.metadata.KeyDescriptor;
import org.opensaml.saml.saml2.metadata.NameIDFormat;
import org.opensaml.saml.saml2.metadata.SPSSODescriptor;
import org.opensaml.saml.saml2.metadata.SingleLogoutService;
import org.opensaml.saml.saml2.metadata.impl.EntityDescriptorMarshaller;
@ -87,6 +88,9 @@ public final class OpenSamlMetadataResolver implements Saml2MetadataResolver {
.addAll(buildKeys(registration.getDecryptionX509Credentials(), UsageType.ENCRYPTION));
spSsoDescriptor.getAssertionConsumerServices().add(buildAssertionConsumerService(registration));
spSsoDescriptor.getSingleLogoutServices().add(buildSingleLogoutService(registration));
if (registration.getNameIdFormat() != null) {
spSsoDescriptor.getNameIDFormats().add(buildNameIDFormat(registration));
}
return spSsoDescriptor;
}
@ -133,6 +137,12 @@ public final class OpenSamlMetadataResolver implements Saml2MetadataResolver {
return singleLogoutService;
}
private NameIDFormat buildNameIDFormat(RelyingPartyRegistration registration) {
NameIDFormat nameIdFormat = build(NameIDFormat.DEFAULT_ELEMENT_NAME);
nameIdFormat.setFormat(registration.getNameIdFormat());
return nameIdFormat;
}
@SuppressWarnings("unchecked")
private <T> T build(QName elementName) {
XMLObjectBuilder<?> builder = XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(elementName);

View File

@ -87,6 +87,8 @@ public final class RelyingPartyRegistration {
private final Saml2MessageBinding singleLogoutServiceBinding;
private final String nameIdFormat;
private final ProviderDetails providerDetails;
private final List<org.springframework.security.saml2.credentials.Saml2X509Credential> credentials;
@ -98,7 +100,7 @@ public final class RelyingPartyRegistration {
private RelyingPartyRegistration(String registrationId, String entityId, String assertionConsumerServiceLocation,
Saml2MessageBinding assertionConsumerServiceBinding, String singleLogoutServiceLocation,
String singleLogoutServiceResponseLocation, Saml2MessageBinding singleLogoutServiceBinding,
ProviderDetails providerDetails,
ProviderDetails providerDetails, String nameIdFormat,
Collection<org.springframework.security.saml2.credentials.Saml2X509Credential> credentials,
Collection<Saml2X509Credential> decryptionX509Credentials,
Collection<Saml2X509Credential> signingX509Credentials) {
@ -129,6 +131,7 @@ public final class RelyingPartyRegistration {
this.singleLogoutServiceLocation = singleLogoutServiceLocation;
this.singleLogoutServiceResponseLocation = singleLogoutServiceResponseLocation;
this.singleLogoutServiceBinding = singleLogoutServiceBinding;
this.nameIdFormat = nameIdFormat;
this.providerDetails = providerDetails;
this.credentials = Collections.unmodifiableList(new LinkedList<>(credentials));
this.decryptionX509Credentials = Collections.unmodifiableList(new LinkedList<>(decryptionX509Credentials));
@ -234,6 +237,15 @@ public final class RelyingPartyRegistration {
return this.singleLogoutServiceResponseLocation;
}
/**
* Get the NameID format.
* @return the NameID format
* @since 5.7
*/
public String getNameIdFormat() {
return this.nameIdFormat;
}
/**
* Get the {@link Collection} of decryption {@link Saml2X509Credential}s associated
* with this relying party
@ -424,6 +436,7 @@ public final class RelyingPartyRegistration {
.singleLogoutServiceLocation(registration.getSingleLogoutServiceLocation())
.singleLogoutServiceResponseLocation(registration.getSingleLogoutServiceResponseLocation())
.singleLogoutServiceBinding(registration.getSingleLogoutServiceBinding())
.nameIdFormat(registration.getNameIdFormat())
.assertingPartyDetails((assertingParty) -> assertingParty
.entityId(registration.getAssertingPartyDetails().getEntityId())
.wantAuthnRequestsSigned(registration.getAssertingPartyDetails().getWantAuthnRequestsSigned())
@ -1018,6 +1031,8 @@ public final class RelyingPartyRegistration {
private Saml2MessageBinding singleLogoutServiceBinding = Saml2MessageBinding.POST;
private String nameIdFormat = null;
private ProviderDetails.Builder providerDetails = new ProviderDetails.Builder();
private Collection<org.springframework.security.saml2.credentials.Saml2X509Credential> credentials = new HashSet<>();
@ -1173,6 +1188,17 @@ public final class RelyingPartyRegistration {
return this;
}
/**
* Set the NameID format
* @param nameIdFormat
* @return the {@link Builder} for further configuration
* @since 5.7
*/
public Builder nameIdFormat(String nameIdFormat) {
this.nameIdFormat = nameIdFormat;
return this;
}
/**
* Apply this {@link Consumer} to further configure the Asserting Party details
* @param assertingPartyDetails The {@link Consumer} to apply
@ -1321,7 +1347,7 @@ public final class RelyingPartyRegistration {
return new RelyingPartyRegistration(this.registrationId, this.entityId,
this.assertionConsumerServiceLocation, this.assertionConsumerServiceBinding,
this.singleLogoutServiceLocation, this.singleLogoutServiceResponseLocation,
this.singleLogoutServiceBinding, this.providerDetails.build(), this.credentials,
this.singleLogoutServiceBinding, this.providerDetails.build(), this.nameIdFormat, this.credentials,
this.decryptionX509Credentials, this.signingX509Credentials);
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -27,8 +27,10 @@ import org.opensaml.core.xml.config.XMLObjectProviderRegistry;
import org.opensaml.saml.common.xml.SAMLConstants;
import org.opensaml.saml.saml2.core.AuthnRequest;
import org.opensaml.saml.saml2.core.Issuer;
import org.opensaml.saml.saml2.core.NameIDPolicy;
import org.opensaml.saml.saml2.core.impl.AuthnRequestBuilder;
import org.opensaml.saml.saml2.core.impl.IssuerBuilder;
import org.opensaml.saml.saml2.core.impl.NameIDPolicyBuilder;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.saml2.core.OpenSamlInitializationService;
@ -56,6 +58,8 @@ public final class OpenSaml4AuthenticationRequestFactory implements Saml2Authent
private final IssuerBuilder issuerBuilder;
private final NameIDPolicyBuilder nameIdPolicyBuilder;
private Clock clock = Clock.systemUTC();
private Converter<Saml2AuthenticationRequestContext, AuthnRequest> authenticationRequestContextConverter;
@ -69,6 +73,8 @@ public final class OpenSaml4AuthenticationRequestFactory implements Saml2Authent
this.authnRequestBuilder = (AuthnRequestBuilder) registry.getBuilderFactory()
.getBuilder(AuthnRequest.DEFAULT_ELEMENT_NAME);
this.issuerBuilder = (IssuerBuilder) registry.getBuilderFactory().getBuilder(Issuer.DEFAULT_ELEMENT_NAME);
this.nameIdPolicyBuilder = (NameIDPolicyBuilder) registry.getBuilderFactory()
.getBuilder(NameIDPolicy.DEFAULT_ELEMENT_NAME);
}
/**
@ -152,6 +158,9 @@ public final class OpenSaml4AuthenticationRequestFactory implements Saml2Authent
auth.setProtocolBinding(SAMLConstants.SAML2_POST_BINDING_URI);
}
auth.setProtocolBinding(protocolBinding);
if (auth.getNameIDPolicy() == null) {
setNameIdPolicy(auth, context.getRelyingPartyRegistration());
}
Issuer iss = this.issuerBuilder.buildObject();
iss.setValue(issuer);
auth.setIssuer(iss);
@ -160,6 +169,15 @@ public final class OpenSaml4AuthenticationRequestFactory implements Saml2Authent
return auth;
}
private void setNameIdPolicy(AuthnRequest authnRequest, RelyingPartyRegistration registration) {
if (!StringUtils.hasText(registration.getNameIdFormat())) {
return;
}
NameIDPolicy nameIdPolicy = this.nameIdPolicyBuilder.buildObject();
nameIdPolicy.setFormat(registration.getNameIdFormat());
authnRequest.setNameIDPolicy(nameIdPolicy);
}
/**
* Set the strategy for building an {@link AuthnRequest} from a given context
* @param authenticationRequestContextConverter the conversion strategy to use

View File

@ -242,6 +242,18 @@ public class OpenSaml4AuthenticationRequestFactoryTests {
assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT);
}
@Test
public void createAuthenticationRequestWhenSetNameIDPolicyThenReturnsCorrectNameIDPolicy() {
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().nameIdFormat("format").build();
this.context = this.contextBuilder.relayState("Relay State Value").relyingPartyRegistration(registration)
.build();
AuthnRequest authn = getAuthNRequest(Saml2MessageBinding.POST);
assertThat(authn.getNameIDPolicy()).isNotNull();
assertThat(authn.getNameIDPolicy().getAllowCreate()).isFalse();
assertThat(authn.getNameIDPolicy().getFormat()).isEqualTo("format");
assertThat(authn.getNameIDPolicy().getSPNameQualifier()).isNull();
}
private AuthnRequest authnRequest() {
AuthnRequest authnRequest = TestOpenSamlObjects.authnRequest();
authnRequest.setIssueInstant(Instant.now());

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -61,4 +61,13 @@ public class OpenSamlMetadataResolverTests {
.contains("ResponseLocation=\"https://rp.example.org/logout/saml2/response\"");
}
@Test
public void resolveWhenRelyingPartyNameIDFormatThenMetadataMatches() {
RelyingPartyRegistration relyingPartyRegistration = TestRelyingPartyRegistrations.full().nameIdFormat("format")
.build();
OpenSamlMetadataResolver openSamlMetadataResolver = new OpenSamlMetadataResolver();
String metadata = openSamlMetadataResolver.resolve(relyingPartyRegistration);
assertThat(metadata).contains("<md:NameIDFormat>format</md:NameIDFormat>");
}
}

View File

@ -28,6 +28,7 @@ public class RelyingPartyRegistrationTests {
@Test
public void withRelyingPartyRegistrationWorks() {
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration()
.nameIdFormat("format")
.assertingPartyDetails((a) -> a.singleSignOnServiceBinding(Saml2MessageBinding.POST))
.assertingPartyDetails((a) -> a.wantAuthnRequestsSigned(false))
.assertingPartyDetails((a) -> a.signingAlgorithms((algs) -> algs.add("alg")))
@ -74,6 +75,7 @@ public class RelyingPartyRegistrationTests {
.isEqualTo(registration.getAssertingPartyDetails().getVerificationX509Credentials());
assertThat(copy.getAssertingPartyDetails().getSigningAlgorithms())
.isEqualTo(registration.getAssertingPartyDetails().getSigningAlgorithms());
assertThat(copy.getNameIdFormat()).isEqualTo(registration.getNameIdFormat());
}
@Test