diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/RequestMatcherMetadataResponseResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/RequestMatcherMetadataResponseResolver.java index c5f2453a88..84af43398e 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/RequestMatcherMetadataResponseResolver.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/RequestMatcherMetadataResponseResolver.java @@ -19,8 +19,6 @@ package org.springframework.security.saml2.provider.service.metadata; import java.io.UnsupportedEncodingException; import java.net.URLEncoder; import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.Collection; import java.util.Collections; import java.util.LinkedHashMap; import java.util.Map; @@ -126,21 +124,19 @@ public final class RequestMatcherMetadataResponseResolver implements Saml2Metada Iterable registrations) { Map results = new LinkedHashMap<>(); for (RelyingPartyRegistration registration : registrations) { - results.put(registration.getEntityId(), registration); - } - Collection resolved = new ArrayList<>(); - for (RelyingPartyRegistration registration : results.values()) { UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration); String entityId = uriResolver.resolve(registration.getEntityId()); - String ssoLocation = uriResolver.resolve(registration.getAssertionConsumerServiceLocation()); - String sloLocation = uriResolver.resolve(registration.getSingleLogoutServiceLocation()); - String sloResponseLocation = uriResolver.resolve(registration.getSingleLogoutServiceResponseLocation()); - resolved.add(registration.mutate().entityId(entityId).assertionConsumerServiceLocation(ssoLocation) - .singleLogoutServiceLocation(sloLocation).singleLogoutServiceResponseLocation(sloResponseLocation) - .build()); + results.computeIfAbsent(entityId, (e) -> { + String ssoLocation = uriResolver.resolve(registration.getAssertionConsumerServiceLocation()); + String sloLocation = uriResolver.resolve(registration.getSingleLogoutServiceLocation()); + String sloResponseLocation = uriResolver.resolve(registration.getSingleLogoutServiceResponseLocation()); + return registration.mutate().entityId(entityId).assertionConsumerServiceLocation(ssoLocation) + .singleLogoutServiceLocation(sloLocation) + .singleLogoutServiceResponseLocation(sloResponseLocation).build(); + }); } - String metadata = this.metadata.resolve(resolved); - String value = (resolved.size() == 1) ? resolved.iterator().next().getRegistrationId() + String metadata = this.metadata.resolve(results.values()); + String value = (results.size() == 1) ? results.values().iterator().next().getRegistrationId() : UUID.randomUUID().toString(); String fileName = this.filename.replace("{registrationId}", value); try { diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/RequestMatcherMetadataResponseResolverTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/RequestMatcherMetadataResponseResolverTests.java index af98378f18..32bccdbdac 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/RequestMatcherMetadataResponseResolverTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/RequestMatcherMetadataResponseResolverTests.java @@ -20,6 +20,7 @@ import java.util.Collection; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; @@ -101,6 +102,23 @@ public final class RequestMatcherMetadataResponseResolverTests { assertThat(resolver.resolve(new MockHttpServletRequest())).isNull(); } + // gh-13700 + @Test + void resolveWhenNoRegistrationIdThenResolvesEntityIds() { + RelyingPartyRegistration one = withEntityId("one"); + RelyingPartyRegistration two = withEntityId("two"); + RelyingPartyRegistrationRepository registrations = new InMemoryRelyingPartyRegistrationRepository(one, two); + RequestMatcherMetadataResponseResolver resolver = new RequestMatcherMetadataResponseResolver(registrations, + this.metadataFactory); + given(this.metadataFactory.resolve(any(Collection.class))).willReturn("metadata"); + resolver.resolve(get("/saml2/metadata")); + ArgumentCaptor> captor = ArgumentCaptor.forClass(Collection.class); + verify(this.metadataFactory).resolve(captor.capture()); + Collection resolved = captor.getValue(); + assertThat(resolved).hasSize(2); + assertThat(resolved.iterator().next().getEntityId()).isEqualTo("one"); + } + private MockHttpServletRequest get(String uri) { MockHttpServletRequest request = new MockHttpServletRequest("GET", uri); request.setServletPath(uri); @@ -108,8 +126,8 @@ public final class RequestMatcherMetadataResponseResolverTests { } private RelyingPartyRegistration withEntityId(String entityId) { - return TestRelyingPartyRegistrations.relyingPartyRegistration().registrationId(entityId).entityId(entityId) - .build(); + return TestRelyingPartyRegistrations.relyingPartyRegistration().registrationId(entityId) + .entityId("{registrationId}").build(); } }