Add a new Saml2MetadataFilter constructor for RelyingPartyRegistrationRepository

Closes gh-11815
This commit is contained in:
Mitja Kotnik 2022-12-02 20:26:59 +01:00 committed by Marcus Da Coregio
parent 7561a02cdd
commit 70249e536a
2 changed files with 20 additions and 3 deletions

View File

@ -29,6 +29,7 @@ import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.security.saml2.provider.service.metadata.Saml2MetadataResolver;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert;
@ -62,6 +63,11 @@ public final class Saml2MetadataFilter extends OncePerRequestFilter {
this.saml2MetadataResolver = saml2MetadataResolver;
}
public Saml2MetadataFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository,
Saml2MetadataResolver saml2MetadataResolver) {
this(new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository), saml2MetadataResolver);
}
@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
throws ServletException, IOException {

View File

@ -64,9 +64,7 @@ public class Saml2MetadataFilterTests {
public void setup() {
this.repository = mock(RelyingPartyRegistrationRepository.class);
this.resolver = mock(Saml2MetadataResolver.class);
RelyingPartyRegistrationResolver relyingPartyRegistrationResolver = new DefaultRelyingPartyRegistrationResolver(
this.repository);
this.filter = new Saml2MetadataFilter(relyingPartyRegistrationResolver, this.resolver);
this.filter = new Saml2MetadataFilter(this.repository, this.resolver);
this.request = new MockHttpServletRequest();
this.response = new MockHttpServletResponse();
this.chain = mock(FilterChain.class);
@ -152,6 +150,19 @@ public class Saml2MetadataFilterTests {
verify(this.repository).findByRegistrationId("registration-id");
}
@Test
public void doFilterWhenPathStartsWithOneThenServesMetadata() throws Exception {
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build();
given(this.repository.findByRegistrationId("one")).willReturn(registration);
given(this.resolver.resolve(any())).willReturn("metadata");
this.filter = new Saml2MetadataFilter((id) -> this.repository.findByRegistrationId("one"),
this.resolver);
this.filter.setRequestMatcher(new AntPathRequestMatcher("/metadata"));
this.request.setPathInfo("/metadata");
this.filter.doFilter(this.request, this.response, new MockFilterChain());
verify(this.repository).findByRegistrationId("one");
}
// gh-12026
@Test
public void doFilterWhenCharacterEncodingThenEncodeSpecialCharactersCorrectly() throws Exception {