From 27294b2e11b39044cad7627f4b7247e54bea5302 Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Fri, 25 Oct 2024 16:11:39 -0600 Subject: [PATCH] Allow RelyingPartyRegistration Placeholder Resolution in XML Closes gh-14645 --- ...artyRegistrationsBeanDefinitionParser.java | 28 +++++++++------ ...egistrationsBeanDefinitionParserTests.java | 27 ++++++++++++++ ...ionParserTests-PlaceholderRegistration.xml | 35 +++++++++++++++++++ 3 files changed, 79 insertions(+), 11 deletions(-) create mode 100644 config/src/test/resources/org/springframework/security/config/saml2/RelyingPartyRegistrationsBeanDefinitionParserTests-PlaceholderRegistration.xml diff --git a/config/src/main/java/org/springframework/security/config/saml2/RelyingPartyRegistrationsBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/saml2/RelyingPartyRegistrationsBeanDefinitionParser.java index 30274206b5..a8a10122bc 100644 --- a/config/src/main/java/org/springframework/security/config/saml2/RelyingPartyRegistrationsBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/saml2/RelyingPartyRegistrationsBeanDefinitionParser.java @@ -213,8 +213,10 @@ public final class RelyingPartyRegistrationsBeanDefinitionParser implements Bean private static RelyingPartyRegistration.Builder getBuilderFromMetadataLocationIfPossible( Element relyingPartyRegistrationElt, Map> assertingParties, ParserContext parserContext) { - String registrationId = relyingPartyRegistrationElt.getAttribute(ATT_REGISTRATION_ID); - String metadataLocation = relyingPartyRegistrationElt.getAttribute(ATT_METADATA_LOCATION); + String registrationId = resolveAttribute(parserContext, + relyingPartyRegistrationElt.getAttribute(ATT_REGISTRATION_ID)); + String metadataLocation = resolveAttribute(parserContext, + relyingPartyRegistrationElt.getAttribute(ATT_METADATA_LOCATION)); RelyingPartyRegistration.Builder builder; if (StringUtils.hasText(metadataLocation)) { builder = RelyingPartyRegistrations.fromMetadataLocation(metadataLocation).registrationId(registrationId); @@ -224,20 +226,20 @@ public final class RelyingPartyRegistrationsBeanDefinitionParser implements Bean .assertingPartyMetadata((apBuilder) -> buildAssertingParty(relyingPartyRegistrationElt, assertingParties, apBuilder, parserContext)); } - addRemainingProperties(relyingPartyRegistrationElt, builder); + addRemainingProperties(parserContext, relyingPartyRegistrationElt, builder); return builder; } - private static void addRemainingProperties(Element relyingPartyRegistrationElt, + private static void addRemainingProperties(ParserContext pc, Element relyingPartyRegistrationElt, RelyingPartyRegistration.Builder builder) { - String entityId = relyingPartyRegistrationElt.getAttribute(ATT_ENTITY_ID); - String singleLogoutServiceLocation = relyingPartyRegistrationElt - .getAttribute(ATT_SINGLE_LOGOUT_SERVICE_LOCATION); - String singleLogoutServiceResponseLocation = relyingPartyRegistrationElt - .getAttribute(ATT_SINGLE_LOGOUT_SERVICE_RESPONSE_LOCATION); + String entityId = resolveAttribute(pc, relyingPartyRegistrationElt.getAttribute(ATT_ENTITY_ID)); + String singleLogoutServiceLocation = resolveAttribute(pc, + relyingPartyRegistrationElt.getAttribute(ATT_SINGLE_LOGOUT_SERVICE_LOCATION)); + String singleLogoutServiceResponseLocation = resolveAttribute(pc, + relyingPartyRegistrationElt.getAttribute(ATT_SINGLE_LOGOUT_SERVICE_RESPONSE_LOCATION)); Saml2MessageBinding singleLogoutServiceBinding = getSingleLogoutServiceBinding(relyingPartyRegistrationElt); - String assertionConsumerServiceLocation = relyingPartyRegistrationElt - .getAttribute(ATT_ASSERTION_CONSUMER_SERVICE_LOCATION); + String assertionConsumerServiceLocation = resolveAttribute(pc, + relyingPartyRegistrationElt.getAttribute(ATT_ASSERTION_CONSUMER_SERVICE_LOCATION)); Saml2MessageBinding assertionConsumerServiceBinding = getAssertionConsumerServiceBinding( relyingPartyRegistrationElt); if (StringUtils.hasText(entityId)) { @@ -400,4 +402,8 @@ public final class RelyingPartyRegistrationsBeanDefinitionParser implements Bean } } + private static String resolveAttribute(ParserContext pc, String value) { + return pc.getReaderContext().getEnvironment().resolvePlaceholders(value); + } + } diff --git a/config/src/test/java/org/springframework/security/config/saml2/RelyingPartyRegistrationsBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/saml2/RelyingPartyRegistrationsBeanDefinitionParserTests.java index 5a332582d3..68a6c22ab1 100644 --- a/config/src/test/java/org/springframework/security/config/saml2/RelyingPartyRegistrationsBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/saml2/RelyingPartyRegistrationsBeanDefinitionParserTests.java @@ -35,6 +35,7 @@ import org.springframework.security.saml2.provider.service.registration.InMemory import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; +import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; import org.springframework.security.saml2.provider.service.web.authentication.OpenSaml4AuthenticationRequestResolver; import static org.assertj.core.api.Assertions.assertThat; @@ -288,6 +289,32 @@ public class RelyingPartyRegistrationsBeanDefinitionParserTests { verify(relayStateResolver).convert(request); } + @Test + public void parseWhenPlaceholdersThenResolves() throws Exception { + RelyingPartyRegistration sample = TestRelyingPartyRegistrations.relyingPartyRegistration().build(); + System.setProperty("registration-id", sample.getRegistrationId()); + System.setProperty("entity-id", sample.getEntityId()); + System.setProperty("acs-location", sample.getAssertionConsumerServiceLocation()); + System.setProperty("slo-location", sample.getSingleLogoutServiceLocation()); + System.setProperty("slo-response-location", sample.getSingleLogoutServiceResponseLocation()); + try (MockWebServer web = new MockWebServer()) { + web.start(); + String serverUrl = web.url("/metadata").toString(); + web.enqueue(xmlResponse(METADATA_RESPONSE)); + System.setProperty("metadata-location", serverUrl); + this.spring.configLocations(xml("PlaceholderRegistration")).autowire(); + } + RelyingPartyRegistration registration = this.relyingPartyRegistrationRepository + .findByRegistrationId(sample.getRegistrationId()); + assertThat(registration.getRegistrationId()).isEqualTo(sample.getRegistrationId()); + assertThat(registration.getEntityId()).isEqualTo(sample.getEntityId()); + assertThat(registration.getAssertionConsumerServiceLocation()) + .isEqualTo(sample.getAssertionConsumerServiceLocation()); + assertThat(registration.getSingleLogoutServiceLocation()).isEqualTo(sample.getSingleLogoutServiceLocation()); + assertThat(registration.getSingleLogoutServiceResponseLocation()) + .isEqualTo(sample.getSingleLogoutServiceResponseLocation()); + } + private static MockResponse xmlResponse(String xml) { return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_XML_VALUE).setBody(xml); } diff --git a/config/src/test/resources/org/springframework/security/config/saml2/RelyingPartyRegistrationsBeanDefinitionParserTests-PlaceholderRegistration.xml b/config/src/test/resources/org/springframework/security/config/saml2/RelyingPartyRegistrationsBeanDefinitionParserTests-PlaceholderRegistration.xml new file mode 100644 index 0000000000..f44ca8fe3a --- /dev/null +++ b/config/src/test/resources/org/springframework/security/config/saml2/RelyingPartyRegistrationsBeanDefinitionParserTests-PlaceholderRegistration.xml @@ -0,0 +1,35 @@ + + + + + + + + +