Merge branch '6.1.x'

Closes gh-13701
This commit is contained in:
Josh Cummings 2023-08-18 14:36:45 -06:00
commit 3540dee259
No known key found for this signature in database
GPG Key ID: A306A51F43B8E5A5
2 changed files with 30 additions and 16 deletions

View File

@ -19,8 +19,6 @@ package org.springframework.security.saml2.provider.service.metadata;
import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
@ -126,21 +124,19 @@ public final class RequestMatcherMetadataResponseResolver implements Saml2Metada
Iterable<RelyingPartyRegistration> registrations) {
Map<String, RelyingPartyRegistration> results = new LinkedHashMap<>();
for (RelyingPartyRegistration registration : registrations) {
results.put(registration.getEntityId(), registration);
}
Collection<RelyingPartyRegistration> resolved = new ArrayList<>();
for (RelyingPartyRegistration registration : results.values()) {
UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration);
String entityId = uriResolver.resolve(registration.getEntityId());
String ssoLocation = uriResolver.resolve(registration.getAssertionConsumerServiceLocation());
String sloLocation = uriResolver.resolve(registration.getSingleLogoutServiceLocation());
String sloResponseLocation = uriResolver.resolve(registration.getSingleLogoutServiceResponseLocation());
resolved.add(registration.mutate().entityId(entityId).assertionConsumerServiceLocation(ssoLocation)
.singleLogoutServiceLocation(sloLocation).singleLogoutServiceResponseLocation(sloResponseLocation)
.build());
results.computeIfAbsent(entityId, (e) -> {
String ssoLocation = uriResolver.resolve(registration.getAssertionConsumerServiceLocation());
String sloLocation = uriResolver.resolve(registration.getSingleLogoutServiceLocation());
String sloResponseLocation = uriResolver.resolve(registration.getSingleLogoutServiceResponseLocation());
return registration.mutate().entityId(entityId).assertionConsumerServiceLocation(ssoLocation)
.singleLogoutServiceLocation(sloLocation)
.singleLogoutServiceResponseLocation(sloResponseLocation).build();
});
}
String metadata = this.metadata.resolve(resolved);
String value = (resolved.size() == 1) ? resolved.iterator().next().getRegistrationId()
String metadata = this.metadata.resolve(results.values());
String value = (results.size() == 1) ? results.values().iterator().next().getRegistrationId()
: UUID.randomUUID().toString();
String fileName = this.filename.replace("{registrationId}", value);
try {

View File

@ -20,6 +20,7 @@ import java.util.Collection;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
@ -101,6 +102,23 @@ public final class RequestMatcherMetadataResponseResolverTests {
assertThat(resolver.resolve(new MockHttpServletRequest())).isNull();
}
// gh-13700
@Test
void resolveWhenNoRegistrationIdThenResolvesEntityIds() {
RelyingPartyRegistration one = withEntityId("one");
RelyingPartyRegistration two = withEntityId("two");
RelyingPartyRegistrationRepository registrations = new InMemoryRelyingPartyRegistrationRepository(one, two);
RequestMatcherMetadataResponseResolver resolver = new RequestMatcherMetadataResponseResolver(registrations,
this.metadataFactory);
given(this.metadataFactory.resolve(any(Collection.class))).willReturn("metadata");
resolver.resolve(get("/saml2/metadata"));
ArgumentCaptor<Collection<RelyingPartyRegistration>> captor = ArgumentCaptor.forClass(Collection.class);
verify(this.metadataFactory).resolve(captor.capture());
Collection<RelyingPartyRegistration> resolved = captor.getValue();
assertThat(resolved).hasSize(2);
assertThat(resolved.iterator().next().getEntityId()).isEqualTo("one");
}
private MockHttpServletRequest get(String uri) {
MockHttpServletRequest request = new MockHttpServletRequest("GET", uri);
request.setServletPath(uri);
@ -108,8 +126,8 @@ public final class RequestMatcherMetadataResponseResolverTests {
}
private RelyingPartyRegistration withEntityId(String entityId) {
return TestRelyingPartyRegistrations.relyingPartyRegistration().registrationId(entityId).entityId(entityId)
.build();
return TestRelyingPartyRegistrations.relyingPartyRegistration().registrationId(entityId)
.entityId("{registrationId}").build();
}
}