From 1c3ce1e401e13514682aeba9262079ebc6800b2a Mon Sep 17 00:00:00 2001 From: Marcus Da Coregio Date: Thu, 22 Dec 2022 11:19:37 -0300 Subject: [PATCH] Fix entity-id ignored in RelyingPartyRegistration XML config Closes gh-11898 --- ...artyRegistrationsBeanDefinitionParser.java | 55 ++++++++++++------ ...egistrationsBeanDefinitionParserTests.java | 58 ++++++++++++++++++- 2 files changed, 94 insertions(+), 19 deletions(-) diff --git a/config/src/main/java/org/springframework/security/config/saml2/RelyingPartyRegistrationsBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/saml2/RelyingPartyRegistrationsBeanDefinitionParser.java index ab55ad0df8..7026b7e238 100644 --- a/config/src/main/java/org/springframework/security/config/saml2/RelyingPartyRegistrationsBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/saml2/RelyingPartyRegistrationsBeanDefinitionParser.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -208,30 +208,49 @@ public final class RelyingPartyRegistrationsBeanDefinitionParser implements Bean ParserContext parserContext) { String registrationId = relyingPartyRegistrationElt.getAttribute(ATT_REGISTRATION_ID); String metadataLocation = relyingPartyRegistrationElt.getAttribute(ATT_METADATA_LOCATION); + RelyingPartyRegistration.Builder builder; + if (StringUtils.hasText(metadataLocation)) { + builder = RelyingPartyRegistrations.fromMetadataLocation(metadataLocation).registrationId(registrationId); + } + else { + builder = RelyingPartyRegistration.withRegistrationId(registrationId) + .assertingPartyDetails((apBuilder) -> buildAssertingParty(relyingPartyRegistrationElt, + assertingParties, apBuilder, parserContext)); + } + addRemainingProperties(relyingPartyRegistrationElt, builder); + return builder; + } + + private static void addRemainingProperties(Element relyingPartyRegistrationElt, + RelyingPartyRegistration.Builder builder) { + String entityId = relyingPartyRegistrationElt.getAttribute(ATT_ENTITY_ID); String singleLogoutServiceLocation = relyingPartyRegistrationElt .getAttribute(ATT_SINGLE_LOGOUT_SERVICE_LOCATION); String singleLogoutServiceResponseLocation = relyingPartyRegistrationElt .getAttribute(ATT_SINGLE_LOGOUT_SERVICE_RESPONSE_LOCATION); Saml2MessageBinding singleLogoutServiceBinding = getSingleLogoutServiceBinding(relyingPartyRegistrationElt); - if (StringUtils.hasText(metadataLocation)) { - return RelyingPartyRegistrations.fromMetadataLocation(metadataLocation).registrationId(registrationId) - .singleLogoutServiceLocation(singleLogoutServiceLocation) - .singleLogoutServiceResponseLocation(singleLogoutServiceResponseLocation) - .singleLogoutServiceBinding(singleLogoutServiceBinding); - } - String entityId = relyingPartyRegistrationElt.getAttribute(ATT_ENTITY_ID); String assertionConsumerServiceLocation = relyingPartyRegistrationElt .getAttribute(ATT_ASSERTION_CONSUMER_SERVICE_LOCATION); Saml2MessageBinding assertionConsumerServiceBinding = getAssertionConsumerServiceBinding( relyingPartyRegistrationElt); - return RelyingPartyRegistration.withRegistrationId(registrationId).entityId(entityId) - .assertionConsumerServiceLocation(assertionConsumerServiceLocation) - .assertionConsumerServiceBinding(assertionConsumerServiceBinding) - .singleLogoutServiceLocation(singleLogoutServiceLocation) - .singleLogoutServiceResponseLocation(singleLogoutServiceResponseLocation) - .singleLogoutServiceBinding(singleLogoutServiceBinding) - .assertingPartyDetails((builder) -> buildAssertingParty(relyingPartyRegistrationElt, assertingParties, - builder, parserContext)); + if (StringUtils.hasText(entityId)) { + builder.entityId(entityId); + } + if (StringUtils.hasText(singleLogoutServiceLocation)) { + builder.singleLogoutServiceLocation(singleLogoutServiceLocation); + } + if (StringUtils.hasText(singleLogoutServiceResponseLocation)) { + builder.singleLogoutServiceResponseLocation(singleLogoutServiceResponseLocation); + } + if (singleLogoutServiceBinding != null) { + builder.singleLogoutServiceBinding(singleLogoutServiceBinding); + } + if (StringUtils.hasText(assertionConsumerServiceLocation)) { + builder.assertionConsumerServiceLocation(assertionConsumerServiceLocation); + } + if (assertionConsumerServiceBinding != null) { + builder.assertionConsumerServiceBinding(assertionConsumerServiceBinding); + } } private static void buildAssertingParty(Element relyingPartyElt, Map> assertingParties, @@ -309,7 +328,7 @@ public final class RelyingPartyRegistrationsBeanDefinitionParser implements Bean if (StringUtils.hasText(assertionConsumerServiceBinding)) { return Saml2MessageBinding.valueOf(assertionConsumerServiceBinding); } - return Saml2MessageBinding.REDIRECT; + return null; } private static Saml2MessageBinding getSingleLogoutServiceBinding(Element relyingPartyRegistrationElt) { @@ -317,7 +336,7 @@ public final class RelyingPartyRegistrationsBeanDefinitionParser implements Bean if (StringUtils.hasText(singleLogoutServiceBinding)) { return Saml2MessageBinding.valueOf(singleLogoutServiceBinding); } - return Saml2MessageBinding.POST; + return null; } private static Saml2X509Credential getSaml2VerificationCredential(String certificateLocation) { diff --git a/config/src/test/java/org/springframework/security/config/saml2/RelyingPartyRegistrationsBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/saml2/RelyingPartyRegistrationsBeanDefinitionParserTests.java index f90a2be1b5..b891b31f32 100644 --- a/config/src/test/java/org/springframework/security/config/saml2/RelyingPartyRegistrationsBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/saml2/RelyingPartyRegistrationsBeanDefinitionParserTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -62,6 +62,27 @@ public class RelyingPartyRegistrationsBeanDefinitionParserTests { "\n"; // @formatter:on + // @formatter:off + private static final String METADATA_LOCATION_OVERRIDE_PROPERTIES_XML_CONFIG = "\n" + + " \n" + + " \n" + + " " + + " \n" + + "\n" + + "\n"; + // @formatter:on + // @formatter:off private static final String METADATA_RESPONSE = "\n" + "\n" + @@ -143,6 +164,41 @@ public class RelyingPartyRegistrationsBeanDefinitionParserTests { .containsExactly("http://www.w3.org/2001/04/xmldsig-more#rsa-sha256"); } + @Test + public void parseWhenMetadataLocationConfiguredAndRegistrationHasPropertiesThenDoNotOverrideSpecifiedProperties() + throws Exception { + this.server = new MockWebServer(); + this.server.start(); + String serverUrl = this.server.url("/").toString(); + this.server.enqueue(xmlResponse(METADATA_RESPONSE)); + String metadataConfig = METADATA_LOCATION_OVERRIDE_PROPERTIES_XML_CONFIG.replace("${metadata-location}", + serverUrl); + this.spring.context(metadataConfig).autowire(); + assertThat(this.relyingPartyRegistrationRepository) + .isInstanceOf(InMemoryRelyingPartyRegistrationRepository.class); + RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationRepository + .findByRegistrationId("one"); + RelyingPartyRegistration.AssertingPartyDetails assertingPartyDetails = relyingPartyRegistration + .getAssertingPartyDetails(); + assertThat(relyingPartyRegistration).isNotNull(); + assertThat(relyingPartyRegistration.getRegistrationId()).isEqualTo("one"); + assertThat(relyingPartyRegistration.getEntityId()).isEqualTo("https://rp.example.org"); + assertThat(relyingPartyRegistration.getAssertionConsumerServiceLocation()) + .isEqualTo("https://rp.example.org/location"); + assertThat(relyingPartyRegistration.getAssertionConsumerServiceBinding()) + .isEqualTo(Saml2MessageBinding.REDIRECT); + assertThat(assertingPartyDetails.getEntityId()) + .isEqualTo("https://simplesaml-for-spring-saml.apps.pcfone.io/saml2/idp/metadata.php"); + assertThat(assertingPartyDetails.getWantAuthnRequestsSigned()).isFalse(); + assertThat(assertingPartyDetails.getVerificationX509Credentials()).hasSize(1); + assertThat(assertingPartyDetails.getEncryptionX509Credentials()).hasSize(1); + assertThat(assertingPartyDetails.getSingleSignOnServiceLocation()) + .isEqualTo("https://simplesaml-for-spring-saml.apps.pcfone.io/saml2/idp/SSOService.php"); + assertThat(assertingPartyDetails.getSingleSignOnServiceBinding()).isEqualTo(Saml2MessageBinding.REDIRECT); + assertThat(assertingPartyDetails.getSigningAlgorithms()) + .containsExactly("http://www.w3.org/2001/04/xmldsig-more#rsa-sha256"); + } + @Test public void parseWhenSingleRelyingPartyRegistrationThenAvailableInRepository() { this.spring.configLocations(xml("SingleRegistration")).autowire();