diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolver.java index 739b4f75de..1e7f120db1 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolver.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolver.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.Base64; import java.util.Collection; import java.util.List; +import java.util.function.Consumer; import javax.xml.namespace.QName; @@ -63,6 +64,9 @@ public final class OpenSamlMetadataResolver implements Saml2MetadataResolver { private final EntityDescriptorMarshaller entityDescriptorMarshaller; + private Consumer entityDescriptorCustomizer = (parameters) -> { + }; + public OpenSamlMetadataResolver() { this.entityDescriptorMarshaller = (EntityDescriptorMarshaller) XMLObjectProviderRegistrySupport .getMarshallerFactory().getMarshaller(EntityDescriptor.DEFAULT_ELEMENT_NAME); @@ -75,9 +79,22 @@ public final class OpenSamlMetadataResolver implements Saml2MetadataResolver { entityDescriptor.setEntityID(relyingPartyRegistration.getEntityId()); SPSSODescriptor spSsoDescriptor = buildSpSsoDescriptor(relyingPartyRegistration); entityDescriptor.getRoleDescriptors(SPSSODescriptor.DEFAULT_ELEMENT_NAME).add(spSsoDescriptor); + this.entityDescriptorCustomizer + .accept(new EntityDescriptorParameters(entityDescriptor, relyingPartyRegistration)); return serialize(entityDescriptor); } + /** + * Set a {@link Consumer} for modifying the OpenSAML {@link EntityDescriptor} + * @param entityDescriptorCustomizer a consumer that accepts an + * {@link EntityDescriptorParameters} + * @since 5.7 + */ + public void setEntityDescriptorCustomizer(Consumer entityDescriptorCustomizer) { + Assert.notNull(entityDescriptorCustomizer, "entityDescriptorCustomizer cannot be null"); + this.entityDescriptorCustomizer = entityDescriptorCustomizer; + } + private SPSSODescriptor buildSpSsoDescriptor(RelyingPartyRegistration registration) { SPSSODescriptor spSsoDescriptor = build(SPSSODescriptor.DEFAULT_ELEMENT_NAME); spSsoDescriptor.addSupportedProtocol(SAMLConstants.SAML20P_NS); @@ -163,4 +180,25 @@ public final class OpenSamlMetadataResolver implements Saml2MetadataResolver { } } + public static final class EntityDescriptorParameters { + + private final EntityDescriptor entityDescriptor; + + private final RelyingPartyRegistration registration; + + public EntityDescriptorParameters(EntityDescriptor entityDescriptor, RelyingPartyRegistration registration) { + this.entityDescriptor = entityDescriptor; + this.registration = registration; + } + + public EntityDescriptor getEntityDescriptor() { + return this.entityDescriptor; + } + + public RelyingPartyRegistration getRegistration() { + return this.registration; + } + + } + } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolverTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolverTests.java index f67cfcafc2..2f7cd17143 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolverTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolverTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -78,4 +78,15 @@ public class OpenSamlMetadataResolverTests { assertThat(metadata).doesNotContain("ResponseLocation"); } + @Test + public void resolveWhenEntityDescriptorCustomizerThenUses() { + RelyingPartyRegistration relyingPartyRegistration = TestRelyingPartyRegistrations.full() + .entityId("originalEntityId").build(); + OpenSamlMetadataResolver openSamlMetadataResolver = new OpenSamlMetadataResolver(); + openSamlMetadataResolver.setEntityDescriptorCustomizer( + (parameters) -> parameters.getEntityDescriptor().setEntityID("overriddenEntityId")); + String metadata = openSamlMetadataResolver.resolve(relyingPartyRegistration); + assertThat(metadata).contains("