Issue gh-11815
This commit is contained in:
Marcus Da Coregio 2022-12-05 10:16:08 -08:00
parent 70249e536a
commit 369bc71c81
2 changed files with 27 additions and 5 deletions

View File

@ -63,6 +63,16 @@ public final class Saml2MetadataFilter extends OncePerRequestFilter {
this.saml2MetadataResolver = saml2MetadataResolver; this.saml2MetadataResolver = saml2MetadataResolver;
} }
/**
* Constructs an instance of {@link Saml2MetadataFilter} using the provided
* parameters. The {@link #relyingPartyRegistrationResolver} field will be initialized
* with a {@link DefaultRelyingPartyRegistrationResolver} instance using the provided
* {@link RelyingPartyRegistrationRepository}
* @param relyingPartyRegistrationRepository the
* {@link RelyingPartyRegistrationRepository} to use
* @param saml2MetadataResolver the {@link Saml2MetadataResolver} to use
* @since 6.1
*/
public Saml2MetadataFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository, public Saml2MetadataFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository,
Saml2MetadataResolver saml2MetadataResolver) { Saml2MetadataResolver saml2MetadataResolver) {
this(new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository), saml2MetadataResolver); this(new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository), saml2MetadataResolver);

View File

@ -33,6 +33,7 @@ import org.springframework.security.saml2.provider.service.registration.RelyingP
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.test.util.ReflectionTestUtils;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
@ -137,7 +138,7 @@ public class Saml2MetadataFilterTests {
} }
@Test @Test
public void doFilterWhenPathStartsWithRegistrationIdThenServesMetadata() throws Exception { public void doFilterWhenResolverConstructorAndPathStartsWithRegistrationIdThenServesMetadata() throws Exception {
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build(); RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build();
given(this.repository.findByRegistrationId("registration-id")).willReturn(registration); given(this.repository.findByRegistrationId("registration-id")).willReturn(registration);
given(this.resolver.resolve(any())).willReturn("metadata"); given(this.resolver.resolve(any())).willReturn("metadata");
@ -151,16 +152,17 @@ public class Saml2MetadataFilterTests {
} }
@Test @Test
public void doFilterWhenPathStartsWithOneThenServesMetadata() throws Exception { public void doFilterWhenRelyingPartyRegistrationRepositoryConstructorAndPathStartsWithRegistrationIdThenServesMetadata()
throws Exception {
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build(); RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build();
given(this.repository.findByRegistrationId("one")).willReturn(registration); given(this.repository.findByRegistrationId("registration-id")).willReturn(registration);
given(this.resolver.resolve(any())).willReturn("metadata"); given(this.resolver.resolve(any())).willReturn("metadata");
this.filter = new Saml2MetadataFilter((id) -> this.repository.findByRegistrationId("one"), this.filter = new Saml2MetadataFilter((id) -> this.repository.findByRegistrationId("registration-id"),
this.resolver); this.resolver);
this.filter.setRequestMatcher(new AntPathRequestMatcher("/metadata")); this.filter.setRequestMatcher(new AntPathRequestMatcher("/metadata"));
this.request.setPathInfo("/metadata"); this.request.setPathInfo("/metadata");
this.filter.doFilter(this.request, this.response, new MockFilterChain()); this.filter.doFilter(this.request, this.response, new MockFilterChain());
verify(this.repository).findByRegistrationId("one"); verify(this.repository).findByRegistrationId("registration-id");
} }
// gh-12026 // gh-12026
@ -196,4 +198,14 @@ public class Saml2MetadataFilterTests {
.withMessage("metadataFilename must contain a {registrationId} match variable"); .withMessage("metadataFilename must contain a {registrationId} match variable");
} }
@Test
public void constructorWhenRelyingPartyRegistrationRepositoryThenUses() {
RelyingPartyRegistrationRepository repository = mock(RelyingPartyRegistrationRepository.class);
this.filter = new Saml2MetadataFilter(repository, this.resolver);
DefaultRelyingPartyRegistrationResolver relyingPartyRegistrationResolver = (DefaultRelyingPartyRegistrationResolver) ReflectionTestUtils
.getField(this.filter, "relyingPartyRegistrationResolver");
relyingPartyRegistrationResolver.resolve(this.request, "one");
verify(repository).findByRegistrationId("one");
}
} }