Add RelyingPartyRegistrationResolver

Closes gh-9486
This commit is contained in:
Josh Cummings 2021-03-02 07:55:05 -07:00
parent efe42b93ce
commit 2f734a0975
No known key found for this signature in database
GPG Key ID: 49EF60DD7FF83443
7 changed files with 172 additions and 33 deletions

View File

@ -555,19 +555,24 @@ There are a number of reasons you may want to customize. Among them:
* You may know that you will never be a multi-tenant application and so want to have a simpler URL scheme * You may know that you will never be a multi-tenant application and so want to have a simpler URL scheme
* You may identify tenants in a way other than by the URI path * You may identify tenants in a way other than by the URI path
To customize the way that a `RelyingPartyRegistration` is resolved, you can configure a custom `Converter<HttpServletRequest, RelyingPartyRegistration>`. To customize the way that a `RelyingPartyRegistration` is resolved, you can configure a custom `RelyingPartyRegistrationResolver`.
The default looks up the registration id from the URI's last path element and looks it up in your `RelyingPartyRegistrationRepository`. The default looks up the registration id from the URI's last path element and looks it up in your `RelyingPartyRegistrationRepository`.
You can provide a simpler resolver that, for example, always returns the same relying party: You can provide a simpler resolver that, for example, always returns the same relying party:
[source,java] [source,java]
---- ----
public class SingleRelyingPartyRegistrationResolver public class SingleRelyingPartyRegistrationResolver implements RelyingPartyRegistrationResolver {
implements Converter<HttpServletRequest, RelyingPartyRegistration> {
private final RelyingPartyRegistrationResolver delegate;
public SingleRelyingPartyRegistrationResolver(RelyingPartyRegistrationRepository registrations) {
this.delegate = new DefaultRelyingPartyRegistrationResolver(registrations);
}
@Override @Override
public RelyingPartyRegistration convert(HttpServletRequest request) { public RelyingPartyRegistration resolve(HttpServletRequest request, String registrationId) {
return this.relyingParty; return this.delegate.resolve(request, "single");
} }
} }
---- ----
@ -1015,7 +1020,7 @@ You can publish a metadata endpoint by adding the `Saml2MetadataFilter` to the f
[source,java] [source,java]
---- ----
Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver = DefaultRelyingPartyRegistrationResolver relyingPartyRegistrationResolver =
new DefaultRelyingPartyRegistrationResolver(this.relyingPartyRegistrationRepository); new DefaultRelyingPartyRegistrationResolver(this.relyingPartyRegistrationRepository);
Saml2MetadataFilter filter = new Saml2MetadataFilter( Saml2MetadataFilter filter = new Saml2MetadataFilter(
relyingPartyRegistrationResolver, relyingPartyRegistrationResolver,
@ -1035,11 +1040,9 @@ You can change this by calling the `setRequestMatcher` method on the filter:
[source,java] [source,java]
---- ----
filter.setRequestMatcher(new AntPathRequestMatcher("/saml2/metadata/{registrationId}", "GET")); filter.setRequestMatcher(new AntPathRequestMatcher("/saml2/{registrationId}/metadata", "GET"));
---- ----
ensuring that the `registrationId` hint is at the end of the path.
Or, if you have registered a custom relying party registration resolver in the constructor, then you can specify a path without a `registrationId` hint, like so: Or, if you have registered a custom relying party registration resolver in the constructor, then you can specify a path without a `registrationId` hint, like so:
[source,java] [source,java]

View File

@ -42,13 +42,13 @@ import org.springframework.web.util.UriComponentsBuilder;
* @since 5.4 * @since 5.4
*/ */
public final class DefaultRelyingPartyRegistrationResolver public final class DefaultRelyingPartyRegistrationResolver
implements Converter<HttpServletRequest, RelyingPartyRegistration> { implements Converter<HttpServletRequest, RelyingPartyRegistration>, RelyingPartyRegistrationResolver {
private static final char PATH_DELIMITER = '/'; private static final char PATH_DELIMITER = '/';
private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository; private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
private final Converter<HttpServletRequest, String> registrationIdResolver = new RegistrationIdResolver(); private final RequestMatcher registrationRequestMatcher = new AntPathRequestMatcher("/**/{registrationId}");
public DefaultRelyingPartyRegistrationResolver( public DefaultRelyingPartyRegistrationResolver(
RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) { RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) {
@ -56,14 +56,28 @@ public final class DefaultRelyingPartyRegistrationResolver
this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository; this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository;
} }
/**
* {@inheritDoc}
*/
@Override @Override
public RelyingPartyRegistration convert(HttpServletRequest request) { public RelyingPartyRegistration convert(HttpServletRequest request) {
String registrationId = this.registrationIdResolver.convert(request); return resolve(request, null);
if (registrationId == null) { }
/**
* {@inheritDoc}
*/
@Override
public RelyingPartyRegistration resolve(HttpServletRequest request, String relyingPartyRegistrationId) {
if (relyingPartyRegistrationId == null) {
relyingPartyRegistrationId = this.registrationRequestMatcher.matcher(request).getVariables()
.get("registrationId");
}
if (relyingPartyRegistrationId == null) {
return null; return null;
} }
RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationRepository RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationRepository
.findByRegistrationId(registrationId); .findByRegistrationId(relyingPartyRegistrationId);
if (relyingPartyRegistration == null) { if (relyingPartyRegistration == null) {
return null; return null;
} }
@ -111,16 +125,4 @@ public final class DefaultRelyingPartyRegistrationResolver
return uriComponents.toUriString(); return uriComponents.toUriString();
} }
private static class RegistrationIdResolver implements Converter<HttpServletRequest, String> {
private final RequestMatcher requestMatcher = new AntPathRequestMatcher("/**/{registrationId}");
@Override
public String convert(HttpServletRequest request) {
RequestMatcher.MatchResult result = this.requestMatcher.matcher(request);
return result.getVariables().get("registrationId");
}
}
} }

View File

@ -0,0 +1,46 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.saml2.provider.service.web;
import javax.servlet.http.HttpServletRequest;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
/**
* A contract for resolving a {@link RelyingPartyRegistration} from the HTTP request
*
* @author Josh Cummings
* @since 5.5
*/
public interface RelyingPartyRegistrationResolver extends Converter<HttpServletRequest, RelyingPartyRegistration> {
@Override
default RelyingPartyRegistration convert(HttpServletRequest request) {
return resolve(request, null);
}
/**
* Resolve a {@link RelyingPartyRegistration} from the HTTP request, using the
* {@code relyingPartyRegistrationId}, if it is provided
* @param request the HTTP request
* @param relyingPartyRegistrationId the {@link RelyingPartyRegistration} identifier
* @return the resolved {@link RelyingPartyRegistration}
*/
RelyingPartyRegistration resolve(HttpServletRequest request, String relyingPartyRegistrationId);
}

View File

@ -46,7 +46,7 @@ public final class Saml2MetadataFilter extends OncePerRequestFilter {
public static final String DEFAULT_METADATA_FILE_NAME = "saml-{registrationId}-metadata.xml"; public static final String DEFAULT_METADATA_FILE_NAME = "saml-{registrationId}-metadata.xml";
private final Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationConverter; private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver;
private final Saml2MetadataResolver saml2MetadataResolver; private final Saml2MetadataResolver saml2MetadataResolver;
@ -55,11 +55,15 @@ public final class Saml2MetadataFilter extends OncePerRequestFilter {
private RequestMatcher requestMatcher = new AntPathRequestMatcher( private RequestMatcher requestMatcher = new AntPathRequestMatcher(
"/saml2/service-provider-metadata/{registrationId}"); "/saml2/service-provider-metadata/{registrationId}");
public Saml2MetadataFilter( public Saml2MetadataFilter(Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver,
Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationConverter,
Saml2MetadataResolver saml2MetadataResolver) { Saml2MetadataResolver saml2MetadataResolver) {
this.relyingPartyRegistrationConverter = relyingPartyRegistrationConverter; if (relyingPartyRegistrationResolver instanceof RelyingPartyRegistrationResolver) {
this.relyingPartyRegistrationResolver = (RelyingPartyRegistrationResolver) relyingPartyRegistrationResolver;
}
else {
this.relyingPartyRegistrationResolver = (request, id) -> relyingPartyRegistrationResolver.convert(request);
}
this.saml2MetadataResolver = saml2MetadataResolver; this.saml2MetadataResolver = saml2MetadataResolver;
} }
@ -71,14 +75,15 @@ public final class Saml2MetadataFilter extends OncePerRequestFilter {
chain.doFilter(request, response); chain.doFilter(request, response);
return; return;
} }
RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationConverter.convert(request); String registrationId = matcher.getVariables().get("registrationId");
RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationResolver.resolve(request,
registrationId);
if (relyingPartyRegistration == null) { if (relyingPartyRegistration == null) {
response.setStatus(HttpServletResponse.SC_UNAUTHORIZED); response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
return; return;
} }
String metadata = this.saml2MetadataResolver.resolve(relyingPartyRegistration); String metadata = this.saml2MetadataResolver.resolve(relyingPartyRegistration);
String registrationId = relyingPartyRegistration.getRegistrationId(); writeMetadataToResponse(response, relyingPartyRegistration.getRegistrationId(), metadata);
writeMetadataToResponse(response, registrationId, metadata);
} }
private void writeMetadataToResponse(HttpServletResponse response, String registrationId, String metadata) private void writeMetadataToResponse(HttpServletResponse response, String registrationId, String metadata)

View File

@ -22,14 +22,26 @@ import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
public class Saml2WebSsoAuthenticationFilterTests { public class Saml2WebSsoAuthenticationFilterTests {
@ -41,6 +53,8 @@ public class Saml2WebSsoAuthenticationFilterTests {
private HttpServletResponse response = new MockHttpServletResponse(); private HttpServletResponse response = new MockHttpServletResponse();
private AuthenticationManager authenticationManager = mock(AuthenticationManager.class);
@Before @Before
public void setup() { public void setup() {
this.filter = new Saml2WebSsoAuthenticationFilter(this.repository); this.filter = new Saml2WebSsoAuthenticationFilter(this.repository);
@ -84,4 +98,26 @@ public class Saml2WebSsoAuthenticationFilterTests {
.withMessage("No relying party registration found"); .withMessage("No relying party registration found");
} }
@Test
public void doFilterWhenPathStartsWithRegistrationIdThenAuthenticates() throws Exception {
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build();
Authentication authentication = new TestingAuthenticationToken("user", "password");
given(this.repository.findByRegistrationId("registration-id")).willReturn(registration);
given(this.authenticationManager.authenticate(authentication)).willReturn(authentication);
String loginProcessingUrl = "/{registrationId}/login/saml2/sso";
RequestMatcher matcher = new AntPathRequestMatcher(loginProcessingUrl);
DefaultRelyingPartyRegistrationResolver delegate = new DefaultRelyingPartyRegistrationResolver(this.repository);
RelyingPartyRegistrationResolver resolver = (request, id) -> {
String registrationId = matcher.matcher(request).getVariables().get("registrationId");
return delegate.resolve(request, registrationId);
};
Saml2AuthenticationTokenConverter authenticationConverter = new Saml2AuthenticationTokenConverter(resolver);
this.filter = new Saml2WebSsoAuthenticationFilter(authenticationConverter, loginProcessingUrl);
this.filter.setAuthenticationManager(this.authenticationManager);
this.request.setPathInfo("/registration-id/login/saml2/sso");
this.request.setParameter("SAMLResponse", "response");
this.filter.doFilter(this.request, this.response, new MockFilterChain());
verify(this.repository).findByRegistrationId("registration-id");
}
} }

View File

@ -36,7 +36,13 @@ import org.springframework.security.saml2.provider.service.authentication.TestSa
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver;
import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.web.util.HtmlUtils; import org.springframework.web.util.HtmlUtils;
import org.springframework.web.util.UriUtils; import org.springframework.web.util.UriUtils;
@ -216,4 +222,29 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
assertThat(this.response.getStatus()).isEqualTo(401); assertThat(this.response.getStatus()).isEqualTo(401);
} }
@Test
public void doFilterWhenPathStartsWithRegistrationIdThenPosts() throws Exception {
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full()
.assertingPartyDetails((party) -> party.singleSignOnServiceBinding(Saml2MessageBinding.POST)).build();
RequestMatcher matcher = new AntPathRequestMatcher("/{registrationId}/saml2/authenticate");
DefaultRelyingPartyRegistrationResolver delegate = new DefaultRelyingPartyRegistrationResolver(this.repository);
RelyingPartyRegistrationResolver resolver = (request, id) -> {
String registrationId = matcher.matcher(request).getVariables().get("registrationId");
return delegate.resolve(request, registrationId);
};
Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver = new DefaultSaml2AuthenticationRequestContextResolver(
resolver);
Saml2PostAuthenticationRequest authenticationRequest = mock(Saml2PostAuthenticationRequest.class);
given(authenticationRequest.getAuthenticationRequestUri()).willReturn("uri");
given(authenticationRequest.getRelayState()).willReturn("relay");
given(authenticationRequest.getSamlRequest()).willReturn("saml");
given(this.repository.findByRegistrationId("registration-id")).willReturn(registration);
given(this.factory.createPostAuthenticationRequest(any())).willReturn(authenticationRequest);
this.filter = new Saml2WebSsoAuthenticationRequestFilter(authenticationRequestContextResolver, this.factory);
this.filter.setRedirectMatcher(matcher);
this.request.setPathInfo("/registration-id/saml2/authenticate");
this.filter.doFilter(this.request, this.response, new MockFilterChain());
verify(this.repository).findByRegistrationId("registration-id");
}
} }

View File

@ -25,6 +25,7 @@ import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.saml2.core.TestSaml2X509Credentials; import org.springframework.security.saml2.core.TestSaml2X509Credentials;
@ -37,6 +38,7 @@ import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -136,6 +138,20 @@ public class Saml2MetadataFilterTests {
.isEqualTo("attachment; filename=\"%s\"; filename*=UTF-8''%s", fileName, encodedFileName); .isEqualTo("attachment; filename=\"%s\"; filename*=UTF-8''%s", fileName, encodedFileName);
} }
@Test
public void doFilterWhenPathStartsWithRegistrationIdThenServesMetadata() throws Exception {
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build();
given(this.repository.findByRegistrationId("registration-id")).willReturn(registration);
given(this.resolver.resolve(any())).willReturn("metadata");
DefaultRelyingPartyRegistrationResolver resolver = new DefaultRelyingPartyRegistrationResolver(
(id) -> this.repository.findByRegistrationId("registration-id"));
this.filter = new Saml2MetadataFilter(resolver, this.resolver);
this.filter.setRequestMatcher(new AntPathRequestMatcher("/metadata"));
this.request.setPathInfo("/metadata");
this.filter.doFilter(this.request, this.response, new MockFilterChain());
verify(this.repository).findByRegistrationId("registration-id");
}
@Test @Test
public void setRequestMatcherWhenNullThenIllegalArgument() { public void setRequestMatcherWhenNullThenIllegalArgument() {
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setRequestMatcher(null)); assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setRequestMatcher(null));