diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/RequestMatcherMetadataResponseResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/RequestMatcherMetadataResponseResolver.java new file mode 100644 index 0000000000..c5f2453a88 --- /dev/null +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/RequestMatcherMetadataResponseResolver.java @@ -0,0 +1,182 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.metadata; + +import java.io.UnsupportedEncodingException; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.UUID; + +import jakarta.servlet.http.HttpServletRequest; + +import org.springframework.security.saml2.Saml2Exception; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers; +import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers.UriResolver; +import org.springframework.security.web.util.matcher.AntPathRequestMatcher; +import org.springframework.security.web.util.matcher.OrRequestMatcher; +import org.springframework.security.web.util.matcher.RequestMatcher; +import org.springframework.util.Assert; + +/** + * An implementation of {@link Saml2MetadataResponseResolver} that identifies which + * {@link RelyingPartyRegistration}s to use with a {@link RequestMatcher} + * + * @author Josh Cummings + * @since 6.1 + */ +public final class RequestMatcherMetadataResponseResolver implements Saml2MetadataResponseResolver { + + private static final String DEFAULT_METADATA_FILENAME = "saml-{registrationId}-metadata.xml"; + + private RequestMatcher matcher = new OrRequestMatcher( + new AntPathRequestMatcher("/saml2/service-provider-metadata/{registrationId}"), + new AntPathRequestMatcher("/saml2/metadata/{registrationId}"), + new AntPathRequestMatcher("/saml2/metadata")); + + private String filename = DEFAULT_METADATA_FILENAME; + + private final RelyingPartyRegistrationRepository registrations; + + private final Saml2MetadataResolver metadata; + + /** + * Construct a {@link RequestMatcherMetadataResponseResolver} + * @param registrations the source for relying party metadata + * @param metadata the strategy for converting {@link RelyingPartyRegistration}s into + * metadata + */ + public RequestMatcherMetadataResponseResolver(RelyingPartyRegistrationRepository registrations, + Saml2MetadataResolver metadata) { + Assert.notNull(registrations, "relyingPartyRegistrationRepository cannot be null"); + Assert.notNull(metadata, "saml2MetadataResolver cannot be null"); + this.registrations = registrations; + this.metadata = metadata; + } + + /** + * Construct and serialize a relying party's SAML 2.0 metadata based on the given + * {@link HttpServletRequest}. Uses the configured {@link RequestMatcher} to identify + * the metadata request, including looking for any indicated {@code registrationId}. + * + *

+ * If a {@code registrationId} is found in the request, it will attempt to use that, + * erroring if no {@link RelyingPartyRegistration} is found. + * + *

+ * If no {@code registrationId} is found in the request, it will attempt to show all + * {@link RelyingPartyRegistration}s in an {@code }. To + * exercise this functionality, the provided + * {@link RelyingPartyRegistrationRepository} needs to implement {@link Iterable}. + * @param request the HTTP request + * @return a {@link Saml2MetadataResponse} instance + * @throws Saml2Exception if the {@link RequestMatcher} specifies a non-existent + * {@code registrationId} + */ + @Override + public Saml2MetadataResponse resolve(HttpServletRequest request) { + RequestMatcher.MatchResult result = this.matcher.matcher(request); + if (!result.isMatch()) { + return null; + } + String registrationId = result.getVariables().get("registrationId"); + Saml2MetadataResponse response = responseByRegistrationId(request, registrationId); + if (response != null) { + return response; + } + if (this.registrations instanceof Iterable) { + Iterable registrations = (Iterable) this.registrations; + return responseByIterable(request, registrations); + } + return null; + } + + private Saml2MetadataResponse responseByRegistrationId(HttpServletRequest request, String registrationId) { + if (registrationId == null) { + return null; + } + RelyingPartyRegistration registration = this.registrations.findByRegistrationId(registrationId); + if (registration == null) { + throw new Saml2Exception("registration not found"); + } + return responseByIterable(request, Collections.singleton(registration)); + } + + private Saml2MetadataResponse responseByIterable(HttpServletRequest request, + Iterable registrations) { + Map results = new LinkedHashMap<>(); + for (RelyingPartyRegistration registration : registrations) { + results.put(registration.getEntityId(), registration); + } + Collection resolved = new ArrayList<>(); + for (RelyingPartyRegistration registration : results.values()) { + UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration); + String entityId = uriResolver.resolve(registration.getEntityId()); + String ssoLocation = uriResolver.resolve(registration.getAssertionConsumerServiceLocation()); + String sloLocation = uriResolver.resolve(registration.getSingleLogoutServiceLocation()); + String sloResponseLocation = uriResolver.resolve(registration.getSingleLogoutServiceResponseLocation()); + resolved.add(registration.mutate().entityId(entityId).assertionConsumerServiceLocation(ssoLocation) + .singleLogoutServiceLocation(sloLocation).singleLogoutServiceResponseLocation(sloResponseLocation) + .build()); + } + String metadata = this.metadata.resolve(resolved); + String value = (resolved.size() == 1) ? resolved.iterator().next().getRegistrationId() + : UUID.randomUUID().toString(); + String fileName = this.filename.replace("{registrationId}", value); + try { + String encodedFileName = URLEncoder.encode(fileName, StandardCharsets.UTF_8.name()); + return new Saml2MetadataResponse(metadata, encodedFileName); + } + catch (UnsupportedEncodingException ex) { + throw new Saml2Exception(ex); + } + } + + /** + * Use this {@link RequestMatcher} to identity which requests to generate metadata + * for. By default, matches {@code /saml2/metadata}, + * {@code /saml2/metadata/{registrationId}}, {@code /saml2/service-provider-metadata}, + * and {@code /saml2/service-provider-metadata/{registrationId}} + * @param requestMatcher the {@link RequestMatcher} to use + */ + public void setRequestMatcher(RequestMatcher requestMatcher) { + Assert.notNull(requestMatcher, "requestMatcher cannot be empty"); + this.matcher = requestMatcher; + } + + /** + * Sets the metadata filename template. If it contains the {@code {registrationId}} + * placeholder, it will be resolved as a random UUID if there are multiple + * {@link RelyingPartyRegistration}s. Otherwise, it will be replaced by the + * {@link RelyingPartyRegistration}'s id. + * + *

+ * The default value is {@code saml-{registrationId}-metadata.xml} + * @param metadataFilename metadata filename, must contain a {registrationId} + */ + public void setMetadataFilename(String metadataFilename) { + Assert.hasText(metadataFilename, "metadataFilename cannot be empty"); + this.filename = metadataFilename; + } + +} diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/Saml2MetadataResponse.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/Saml2MetadataResponse.java new file mode 100644 index 0000000000..508430d735 --- /dev/null +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/Saml2MetadataResponse.java @@ -0,0 +1,38 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.metadata; + +public class Saml2MetadataResponse { + + private final String metadata; + + private final String fileName; + + public Saml2MetadataResponse(String metadata, String fileName) { + this.metadata = metadata; + this.fileName = fileName; + } + + public String getMetadata() { + return this.metadata; + } + + public String getFileName() { + return this.fileName; + } + +} diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/Saml2MetadataResponseResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/Saml2MetadataResponseResolver.java new file mode 100644 index 0000000000..2c3fdcf4a2 --- /dev/null +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/Saml2MetadataResponseResolver.java @@ -0,0 +1,38 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.metadata; + +import jakarta.servlet.http.HttpServletRequest; + +/** + * Resolves Relying Party SAML 2.0 Metadata given details from the + * {@link HttpServletRequest}. + * + * @author Josh Cummings + * @since 6.1 + */ +public interface Saml2MetadataResponseResolver { + + /** + * Construct and serialize a relying party's SAML 2.0 metadata based on the given + * {@link HttpServletRequest} + * @param request the HTTP request + * @return a {@link Saml2MetadataResponse} instance + */ + Saml2MetadataResponse resolve(HttpServletRequest request); + +} diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/OpenSamlAuthenticationTokenConverter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/OpenSamlAuthenticationTokenConverter.java index cb29a370b0..a513bc1bb8 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/OpenSamlAuthenticationTokenConverter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/OpenSamlAuthenticationTokenConverter.java @@ -93,7 +93,6 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo * request. * @param authenticationRequestRepository the * {@link Saml2AuthenticationRequestRepository} to use - * @since 5.6 */ public void setAuthenticationRequestRepository( Saml2AuthenticationRequestRepository authenticationRequestRepository) { diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilter.java index 84c7e20332..91c5c00547 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilter.java @@ -27,7 +27,10 @@ import jakarta.servlet.http.HttpServletResponse; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; +import org.springframework.security.saml2.Saml2Exception; import org.springframework.security.saml2.provider.service.metadata.Saml2MetadataResolver; +import org.springframework.security.saml2.provider.service.metadata.Saml2MetadataResponse; +import org.springframework.security.saml2.provider.service.metadata.Saml2MetadataResponseResolver; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; @@ -46,27 +49,20 @@ public final class Saml2MetadataFilter extends OncePerRequestFilter { public static final String DEFAULT_METADATA_FILE_NAME = "saml-{registrationId}-metadata.xml"; - private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver; - - private final Saml2MetadataResolver saml2MetadataResolver; - - private String metadataFilename = DEFAULT_METADATA_FILE_NAME; - - private RequestMatcher requestMatcher = new AntPathRequestMatcher( - "/saml2/service-provider-metadata/{registrationId}"); + private final Saml2MetadataResponseResolver metadataResolver; public Saml2MetadataFilter(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver, Saml2MetadataResolver saml2MetadataResolver) { Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null"); Assert.notNull(saml2MetadataResolver, "saml2MetadataResolver cannot be null"); - this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver; - this.saml2MetadataResolver = saml2MetadataResolver; + this.metadataResolver = new Saml2MetadataResponseResolverAdapter(relyingPartyRegistrationResolver, + saml2MetadataResolver); } /** * Constructs an instance of {@link Saml2MetadataFilter} using the provided - * parameters. The {@link #relyingPartyRegistrationResolver} field will be initialized - * with a {@link DefaultRelyingPartyRegistrationResolver} instance using the provided + * parameters. The {@link #metadataResolver} field will be initialized with a + * {@link DefaultRelyingPartyRegistrationResolver} instance using the provided * {@link RelyingPartyRegistrationRepository} * @param relyingPartyRegistrationRepository the * {@link RelyingPartyRegistrationRepository} to use @@ -78,35 +74,43 @@ public final class Saml2MetadataFilter extends OncePerRequestFilter { this(new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository), saml2MetadataResolver); } + /** + * Constructs an instance of {@link Saml2MetadataFilter} + * @param metadataResponseResolver the strategy for producing metadata + * @since 6.1 + */ + public Saml2MetadataFilter(Saml2MetadataResponseResolver metadataResponseResolver) { + this.metadataResolver = metadataResponseResolver; + } + @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain chain) throws ServletException, IOException { - RequestMatcher.MatchResult matcher = this.requestMatcher.matcher(request); - if (!matcher.isMatch()) { - chain.doFilter(request, response); - return; + Saml2MetadataResponse metadata; + try { + metadata = this.metadataResolver.resolve(request); } - String registrationId = matcher.getVariables().get("registrationId"); - RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationResolver.resolve(request, - registrationId); - if (relyingPartyRegistration == null) { + catch (Saml2Exception ex) { response.setStatus(HttpServletResponse.SC_UNAUTHORIZED); return; } - String metadata = this.saml2MetadataResolver.resolve(relyingPartyRegistration); - writeMetadataToResponse(response, relyingPartyRegistration.getRegistrationId(), metadata); + if (metadata == null) { + chain.doFilter(request, response); + return; + } + writeMetadataToResponse(response, metadata); } - private void writeMetadataToResponse(HttpServletResponse response, String registrationId, String metadata) + private void writeMetadataToResponse(HttpServletResponse response, Saml2MetadataResponse metadata) throws IOException { response.setContentType(MediaType.APPLICATION_XML_VALUE); - String fileName = this.metadataFilename.replace("{registrationId}", registrationId); - String encodedFileName = URLEncoder.encode(fileName, StandardCharsets.UTF_8.name()); String format = "attachment; filename=\"%s\"; filename*=UTF-8''%s"; + String fileName = metadata.getFileName(); + String encodedFileName = URLEncoder.encode(fileName, StandardCharsets.UTF_8.name()); response.setHeader(HttpHeaders.CONTENT_DISPOSITION, String.format(format, fileName, encodedFileName)); - response.setContentLength(metadata.length()); + response.setContentLength(metadata.getMetadata().length()); response.setCharacterEncoding(StandardCharsets.UTF_8.name()); - response.getWriter().write(metadata); + response.getWriter().write(metadata.getMetadata()); } /** @@ -116,7 +120,9 @@ public final class Saml2MetadataFilter extends OncePerRequestFilter { */ public void setRequestMatcher(RequestMatcher requestMatcher) { Assert.notNull(requestMatcher, "requestMatcher cannot be null"); - this.requestMatcher = requestMatcher; + Assert.isInstanceOf(Saml2MetadataResponseResolverAdapter.class, this.metadataResolver, + "a Saml2MetadataResponseResolver and RequestMatcher cannot be both set on this filter. Please set the request matcher on the Saml2MetadataResponseResolver itself."); + ((Saml2MetadataResponseResolverAdapter) this.metadataResolver).setRequestMatcher(requestMatcher); } /** @@ -132,7 +138,57 @@ public final class Saml2MetadataFilter extends OncePerRequestFilter { Assert.hasText(metadataFilename, "metadataFilename cannot be empty"); Assert.isTrue(metadataFilename.contains("{registrationId}"), "metadataFilename must contain a {registrationId} match variable"); - this.metadataFilename = metadataFilename; + Assert.isInstanceOf(Saml2MetadataResponseResolverAdapter.class, this.metadataResolver, + "a Saml2MetadataResponseResolver and file name cannot be both set on this filter. Please set the file name on the Saml2MetadataResponseResolver itself."); + ((Saml2MetadataResponseResolverAdapter) this.metadataResolver).setMetadataFilename(metadataFilename); + } + + private static final class Saml2MetadataResponseResolverAdapter implements Saml2MetadataResponseResolver { + + private final RelyingPartyRegistrationResolver registrations; + + private RequestMatcher requestMatcher = new AntPathRequestMatcher( + "/saml2/service-provider-metadata/{registrationId}"); + + private final Saml2MetadataResolver metadataResolver; + + private String metadataFilename = DEFAULT_METADATA_FILE_NAME; + + Saml2MetadataResponseResolverAdapter(RelyingPartyRegistrationResolver registrations, + Saml2MetadataResolver metadataResolver) { + this.registrations = registrations; + this.metadataResolver = metadataResolver; + } + + @Override + public Saml2MetadataResponse resolve(HttpServletRequest request) { + RequestMatcher.MatchResult matcher = this.requestMatcher.matcher(request); + if (!matcher.isMatch()) { + return null; + } + String registrationId = matcher.getVariables().get("registrationId"); + RelyingPartyRegistration relyingPartyRegistration = this.registrations.resolve(request, registrationId); + if (relyingPartyRegistration == null) { + throw new Saml2Exception("registration not found"); + } + registrationId = relyingPartyRegistration.getRegistrationId(); + String metadata = this.metadataResolver.resolve(relyingPartyRegistration); + String fileName = this.metadataFilename.replace("{registrationId}", registrationId); + return new Saml2MetadataResponse(metadata, fileName); + } + + void setRequestMatcher(RequestMatcher requestMatcher) { + Assert.notNull(requestMatcher, "requestMatcher cannot be null"); + this.requestMatcher = requestMatcher; + } + + void setMetadataFilename(String metadataFilename) { + Assert.hasText(metadataFilename, "metadataFilename cannot be empty"); + Assert.isTrue(metadataFilename.contains("{registrationId}"), + "metadataFilename must contain a {registrationId} match variable"); + this.metadataFilename = metadataFilename; + } + } } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolverTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolverTests.java index 78f5f560b7..b87be7f111 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolverTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolverTests.java @@ -100,8 +100,8 @@ public class OpenSamlMetadataResolverTests { OpenSamlMetadataResolver openSamlMetadataResolver = new OpenSamlMetadataResolver(); String metadata = openSamlMetadataResolver.resolve(List.of(one, two)); assertThat(metadata).contains("") - .contains("") + .contains("entityID=\"rp-entity-id\"").contains("entityID=\"two\"") + .contains("").contains("") .contains("MIICgTCCAeoCCQCuVzyqFgMSyDANBgkqhkiG9w0BAQsFADCBhDELMAkGA1UEBh") .contains("Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect\"") .contains("Location=\"https://rp.example.org/acs\" index=\"1\"") diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/RequestMatcherMetadataResponseResolverTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/RequestMatcherMetadataResponseResolverTests.java new file mode 100644 index 0000000000..af98378f18 --- /dev/null +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/RequestMatcherMetadataResponseResolverTests.java @@ -0,0 +1,115 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.metadata; + +import java.util.Collection; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.security.saml2.Saml2Exception; +import org.springframework.security.saml2.provider.service.registration.InMemoryRelyingPartyRegistrationRepository; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +@ExtendWith(MockitoExtension.class) +public final class RequestMatcherMetadataResponseResolverTests { + + @Mock + Saml2MetadataResolver metadataFactory; + + @Test + void saml2MetadataRegistrationIdResolveWhenMatchesThenResolves() { + RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration().build(); + RelyingPartyRegistrationRepository registrations = new InMemoryRelyingPartyRegistrationRepository(registration); + RequestMatcherMetadataResponseResolver resolver = new RequestMatcherMetadataResponseResolver(registrations, + this.metadataFactory); + String registrationId = registration.getRegistrationId(); + given(this.metadataFactory.resolve(any(Collection.class))).willReturn("metadata"); + MockHttpServletRequest request = get("/saml2/metadata/" + registrationId); + Saml2MetadataResponse response = resolver.resolve(request); + assertThat(response.getMetadata()).isEqualTo("metadata"); + assertThat(response.getFileName()).isEqualTo("saml-" + registrationId + "-metadata.xml"); + verify(this.metadataFactory).resolve(any(Collection.class)); + } + + @Test + void saml2MetadataResolveWhenNoMatchingRegistrationThenNull() { + RelyingPartyRegistrationRepository registrations = mock(RelyingPartyRegistrationRepository.class); + RequestMatcherMetadataResponseResolver resolver = new RequestMatcherMetadataResponseResolver(registrations, + this.metadataFactory); + MockHttpServletRequest request = get("/saml2/metadata"); + Saml2MetadataResponse response = resolver.resolve(request); + assertThat(response).isNull(); + } + + @Test + void saml2MetadataRegistrationIdResolveWhenNoMatchingRegistrationThenException() { + RelyingPartyRegistrationRepository registrations = mock(RelyingPartyRegistrationRepository.class); + RequestMatcherMetadataResponseResolver resolver = new RequestMatcherMetadataResponseResolver(registrations, + this.metadataFactory); + MockHttpServletRequest request = get("/saml2/metadata/id"); + assertThatExceptionOfType(Saml2Exception.class).isThrownBy(() -> resolver.resolve(request)); + } + + @Test + void resolveWhenNoRegistrationIdThenResolvesAll() { + RelyingPartyRegistration one = withEntityId("one"); + RelyingPartyRegistration two = withEntityId("two"); + RelyingPartyRegistrationRepository registrations = new InMemoryRelyingPartyRegistrationRepository(one, two); + RequestMatcherMetadataResponseResolver resolver = new RequestMatcherMetadataResponseResolver(registrations, + this.metadataFactory); + given(this.metadataFactory.resolve(any(Collection.class))).willReturn("metadata"); + MockHttpServletRequest request = get("/saml2/metadata"); + Saml2MetadataResponse response = resolver.resolve(request); + assertThat(response.getMetadata()).isEqualTo("metadata"); + assertThat(response.getFileName()).doesNotContain(one.getRegistrationId()).contains("saml") + .contains("metadata.xml"); + verify(this.metadataFactory).resolve(any(Collection.class)); + } + + @Test + void resolveWhenRequestDoesNotMatchThenNull() { + RelyingPartyRegistrationRepository registrations = mock(RelyingPartyRegistrationRepository.class); + RequestMatcherMetadataResponseResolver resolver = new RequestMatcherMetadataResponseResolver(registrations, + this.metadataFactory); + assertThat(resolver.resolve(new MockHttpServletRequest())).isNull(); + } + + private MockHttpServletRequest get(String uri) { + MockHttpServletRequest request = new MockHttpServletRequest("GET", uri); + request.setServletPath(uri); + return request; + } + + private RelyingPartyRegistration withEntityId(String entityId) { + return TestRelyingPartyRegistrations.relyingPartyRegistration().registrationId(entityId).entityId(entityId) + .build(); + } + +} diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilterTests.java index 3c5771e86e..043b05c59a 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilterTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilterTests.java @@ -33,7 +33,6 @@ import org.springframework.security.saml2.provider.service.registration.RelyingP import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; -import org.springframework.test.util.ReflectionTestUtils; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -141,7 +140,7 @@ public class Saml2MetadataFilterTests { public void doFilterWhenResolverConstructorAndPathStartsWithRegistrationIdThenServesMetadata() throws Exception { RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build(); given(this.repository.findByRegistrationId("registration-id")).willReturn(registration); - given(this.resolver.resolve(any())).willReturn("metadata"); + given(this.resolver.resolve(any(RelyingPartyRegistration.class))).willReturn("metadata"); RelyingPartyRegistrationResolver resolver = new DefaultRelyingPartyRegistrationResolver( (id) -> this.repository.findByRegistrationId("registration-id")); this.filter = new Saml2MetadataFilter(resolver, this.resolver); @@ -156,7 +155,7 @@ public class Saml2MetadataFilterTests { throws Exception { RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build(); given(this.repository.findByRegistrationId("registration-id")).willReturn(registration); - given(this.resolver.resolve(any())).willReturn("metadata"); + given(this.resolver.resolve(any(RelyingPartyRegistration.class))).willReturn("metadata"); this.filter = new Saml2MetadataFilter((id) -> this.repository.findByRegistrationId("registration-id"), this.resolver); this.filter.setRequestMatcher(new AntPathRequestMatcher("/metadata")); @@ -199,12 +198,11 @@ public class Saml2MetadataFilterTests { } @Test - public void constructorWhenRelyingPartyRegistrationRepositoryThenUses() { + public void constructorWhenRelyingPartyRegistrationRepositoryThenUses() throws Exception { RelyingPartyRegistrationRepository repository = mock(RelyingPartyRegistrationRepository.class); this.filter = new Saml2MetadataFilter(repository, this.resolver); - DefaultRelyingPartyRegistrationResolver relyingPartyRegistrationResolver = (DefaultRelyingPartyRegistrationResolver) ReflectionTestUtils - .getField(this.filter, "relyingPartyRegistrationResolver"); - relyingPartyRegistrationResolver.resolve(this.request, "one"); + this.request.setPathInfo("/saml2/service-provider-metadata/one"); + this.filter.doFilter(this.request, this.response, this.chain); verify(repository).findByRegistrationId("one"); }