Polish DefaultSaml2AuthenticationRequestContextResolver

Issue gh-8360
Issue gh-8887
This commit is contained in:
Josh Cummings 2020-07-28 17:19:48 -06:00
parent 015281ff53
commit a10c2c6cf8
No known key found for this signature in database
GPG Key ID: 49EF60DD7FF83443
7 changed files with 75 additions and 134 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -35,6 +35,9 @@ import org.springframework.security.saml2.provider.service.registration.RelyingP
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter;
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint;
import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
@ -317,15 +320,16 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>> extend
private final class AuthenticationRequestEndpointConfig {
private String filterProcessingUrl = "/saml2/authenticate/{registrationId}";
private AuthenticationRequestEndpointConfig() {
}
private Filter build(B http) {
Saml2AuthenticationRequestFactory authenticationRequestResolver = getResolver(http);
Saml2AuthenticationRequestContextResolver contextResolver = getContextResolver(http);
return postProcess(new Saml2WebSsoAuthenticationRequestFilter(
Saml2LoginConfigurer.this.relyingPartyRegistrationRepository,
authenticationRequestResolver));
contextResolver, authenticationRequestResolver));
}
private Saml2AuthenticationRequestFactory getResolver(B http) {
@ -335,6 +339,16 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>> extend
}
return resolver;
}
private Saml2AuthenticationRequestContextResolver getContextResolver(B http) {
Saml2AuthenticationRequestContextResolver resolver = getBeanOrNull(http, Saml2AuthenticationRequestContextResolver.class);
if (resolver == null) {
return new DefaultSaml2AuthenticationRequestContextResolver(
new DefaultRelyingPartyRegistrationResolver(
Saml2LoginConfigurer.this.relyingPartyRegistrationRepository));
}
return resolver;
}
}
}

View File

@ -65,10 +65,8 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2A
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.context.HttpRequestResponseHolder;
@ -87,6 +85,7 @@ import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.springframework.security.config.Customizer.withDefaults;
import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext;
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
@ -161,11 +160,11 @@ public class Saml2LoginConfigurerTests {
Saml2AuthenticationRequestContext context = authenticationRequestContext().build();
Saml2AuthenticationRequestContextResolver resolver =
CustomAuthenticationRequestContextResolver.resolver;
when(resolver.resolve(any(HttpServletRequest.class), any(RelyingPartyRegistration.class)))
when(resolver.resolve(any(HttpServletRequest.class)))
.thenReturn(context);
this.mvc.perform(get("/saml2/authenticate/registration-id"))
.andExpect(status().isFound());
verify(resolver).resolve(any(HttpServletRequest.class), any(RelyingPartyRegistration.class));
verify(resolver).resolve(any(HttpServletRequest.class));
}
@Test
@ -276,22 +275,11 @@ public class Saml2LoginConfigurerTests {
@Override
protected void configure(HttpSecurity http) throws Exception {
ObjectPostProcessor<Saml2WebSsoAuthenticationRequestFilter> processor
= new ObjectPostProcessor<Saml2WebSsoAuthenticationRequestFilter>() {
@Override
public <O extends Saml2WebSsoAuthenticationRequestFilter> O postProcess(O filter) {
filter.setAuthenticationRequestContextResolver(resolver);
return filter;
}
};
http
.authorizeRequests(authz -> authz
.anyRequest().authenticated()
)
.saml2Login(saml2 -> saml2
.addObjectPostProcessor(processor)
);
.saml2Login(withDefaults());
}
@Bean

View File

@ -30,6 +30,7 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2R
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
@ -69,9 +70,8 @@ import static java.nio.charset.StandardCharsets.ISO_8859_1;
*/
public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter {
private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
private final Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver;
private Saml2AuthenticationRequestFactory authenticationRequestFactory;
private Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver = new DefaultSaml2AuthenticationRequestContextResolver();
private RequestMatcher redirectMatcher = new AntPathRequestMatcher("/saml2/authenticate/{registrationId}");
@ -83,21 +83,24 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
*/
@Deprecated
public Saml2WebSsoAuthenticationRequestFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) {
this(relyingPartyRegistrationRepository,
this(new DefaultSaml2AuthenticationRequestContextResolver(
new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository)),
new org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationRequestFactory());
}
/**
* Construct a {@link Saml2WebSsoAuthenticationRequestFilter} with the provided parameters
*
* @param relyingPartyRegistrationRepository a repository for relying party configurations
* @param authenticationRequestContextResolver a strategy for formulating a {@link Saml2AuthenticationRequestContext}
* @since 5.4
*/
public Saml2WebSsoAuthenticationRequestFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository,
public Saml2WebSsoAuthenticationRequestFilter(
Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver,
Saml2AuthenticationRequestFactory authenticationRequestFactory) {
Assert.notNull(relyingPartyRegistrationRepository, "relyingPartyRegistrationRepository cannot be null");
Assert.notNull(authenticationRequestContextResolver, "authenticationRequestContextResolver cannot be null");
Assert.notNull(authenticationRequestFactory, "authenticationRequestFactory cannot be null");
this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository;
this.authenticationRequestContextResolver = authenticationRequestContextResolver;
this.authenticationRequestFactory = authenticationRequestFactory;
}
@ -123,17 +126,6 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
this.redirectMatcher = redirectMatcher;
}
/**
* Use the given {@link Saml2AuthenticationRequestContextResolver} that creates a {@link Saml2AuthenticationRequestContext}
*
* @param authenticationRequestContextResolver the {@link Saml2AuthenticationRequestContextResolver} to use
* @since 5.4
*/
public void setAuthenticationRequestContextResolver(Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver) {
Assert.notNull(authenticationRequestContextResolver, "authenticationRequestContextResolver cannot be null");
this.authenticationRequestContextResolver = authenticationRequestContextResolver;
}
/**
* {@inheritDoc}
*/
@ -147,14 +139,12 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
return;
}
String registrationId = matcher.getVariables().get("registrationId");
RelyingPartyRegistration relyingParty =
this.relyingPartyRegistrationRepository.findByRegistrationId(registrationId);
if (relyingParty == null) {
Saml2AuthenticationRequestContext context = this.authenticationRequestContextResolver.resolve(request);
if (context == null) {
response.sendError(HttpServletResponse.SC_UNAUTHORIZED);
return;
}
Saml2AuthenticationRequestContext context = authenticationRequestContextResolver.resolve(request, relyingParty);
RelyingPartyRegistration relyingParty = context.getRelyingPartyRegistration();
if (relyingParty.getAssertingPartyDetails().getSingleSignOnServiceBinding() == Saml2MessageBinding.REDIRECT) {
sendRedirect(response, context);
} else {

View File

@ -16,45 +16,45 @@
package org.springframework.security.saml2.provider.service.web;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
import javax.servlet.http.HttpServletRequest;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;
import static org.springframework.security.web.util.UrlUtils.buildFullRequestUrl;
import static org.springframework.web.util.UriComponentsBuilder.fromHttpUrl;
/**
* The default implementation for {@link Saml2AuthenticationRequestContextResolver}
* which uses the current request and given relying party to formulate a {@link Saml2AuthenticationRequestContext}
*
* @author Shazin Sadakath
* @author Josh Cummings
* @since 5.4
*/
public final class DefaultSaml2AuthenticationRequestContextResolver implements Saml2AuthenticationRequestContextResolver {
private final Log logger = LogFactory.getLog(getClass());
private static final char PATH_DELIMITER = '/';
private final Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver;
public DefaultSaml2AuthenticationRequestContextResolver
(Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver) {
this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
}
/**
* {@inheritDoc}
*/
@Override
public Saml2AuthenticationRequestContext resolve(HttpServletRequest request,
RelyingPartyRegistration relyingParty) {
public Saml2AuthenticationRequestContext resolve(HttpServletRequest request) {
Assert.notNull(request, "request cannot be null");
Assert.notNull(relyingParty, "relyingParty cannot be null");
RelyingPartyRegistration relyingParty = this.relyingPartyRegistrationResolver.convert(request);
if (relyingParty == null) {
return null;
}
if (this.logger.isDebugEnabled()) {
this.logger.debug("Creating SAML 2.0 Authentication Request for Asserting Party [" +
relyingParty.getRegistrationId() + "]");
@ -65,59 +65,11 @@ public final class DefaultSaml2AuthenticationRequestContextResolver implements S
private Saml2AuthenticationRequestContext createRedirectAuthenticationRequestContext(
HttpServletRequest request, RelyingPartyRegistration relyingParty) {
String applicationUri = getApplicationUri(request);
Function<String, String> resolver = templateResolver(applicationUri, relyingParty);
String localSpEntityId = resolver.apply(relyingParty.getEntityId());
String assertionConsumerServiceUrl = resolver.apply(relyingParty.getAssertionConsumerServiceLocation());
return Saml2AuthenticationRequestContext.builder()
.issuer(localSpEntityId)
.issuer(relyingParty.getEntityId())
.relyingPartyRegistration(relyingParty)
.assertionConsumerServiceUrl(assertionConsumerServiceUrl)
.assertionConsumerServiceUrl(relyingParty.getAssertionConsumerServiceLocation())
.relayState(request.getParameter("RelayState"))
.build();
}
private Function<String, String> templateResolver(String applicationUri, RelyingPartyRegistration relyingParty) {
return template -> resolveUrlTemplate(template, applicationUri, relyingParty);
}
private static String resolveUrlTemplate(String template, String baseUrl, RelyingPartyRegistration relyingParty) {
String entityId = relyingParty.getAssertingPartyDetails().getEntityId();
String registrationId = relyingParty.getRegistrationId();
Map<String, String> uriVariables = new HashMap<>();
UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(baseUrl)
.replaceQuery(null)
.fragment(null)
.build();
String scheme = uriComponents.getScheme();
uriVariables.put("baseScheme", scheme == null ? "" : scheme);
String host = uriComponents.getHost();
uriVariables.put("baseHost", host == null ? "" : host);
// following logic is based on HierarchicalUriComponents#toUriString()
int port = uriComponents.getPort();
uriVariables.put("basePort", port == -1 ? "" : ":" + port);
String path = uriComponents.getPath();
if (StringUtils.hasLength(path)) {
if (path.charAt(0) != PATH_DELIMITER) {
path = PATH_DELIMITER + path;
}
}
uriVariables.put("basePath", path == null ? "" : path);
uriVariables.put("baseUrl", uriComponents.toUriString());
uriVariables.put("entityId", StringUtils.hasText(entityId) ? entityId : "");
uriVariables.put("registrationId", StringUtils.hasText(registrationId) ? registrationId : "");
return UriComponentsBuilder.fromUriString(template)
.buildAndExpand(uriVariables)
.toUriString();
}
private static String getApplicationUri(HttpServletRequest request) {
UriComponents uriComponents = fromHttpUrl(buildFullRequestUrl(request))
.replacePath(request.getContextPath())
.replaceQuery(null)
.fragment(null)
.build();
return uriComponents.toUriString();
}
}

View File

@ -16,16 +16,16 @@
package org.springframework.security.saml2.provider.service.web;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import javax.servlet.http.HttpServletRequest;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
/**
* This {@code Saml2AuthenticationRequestContextResolver} formulates a
* <a href="https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf">SAML 2.0 AuthnRequest</a> (line 1968)
*
* @author Shazin Sadakath
* @author Josh Cummings
* @since 5.4
*/
public interface Saml2AuthenticationRequestContextResolver {
@ -35,9 +35,7 @@ public interface Saml2AuthenticationRequestContextResolver {
*
*
* @param request the current request
* @param relyingParty the relying party responsible for saml2 sso authentication
* @return the created {@link Saml2AuthenticationRequestContext} for request/relying party combination
* @return the created {@link Saml2AuthenticationRequestContext} for the request
*/
Saml2AuthenticationRequestContext resolve(HttpServletRequest request,
RelyingPartyRegistration relyingParty);
Saml2AuthenticationRequestContext resolve(HttpServletRequest request);
}

View File

@ -30,6 +30,7 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2A
import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
import org.springframework.web.util.HtmlUtils;
import org.springframework.web.util.UriUtils;
@ -41,6 +42,7 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyPrivateCredential;
import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext;
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST;
public class Saml2WebSsoAuthenticationRequestFilterTests {
@ -49,6 +51,8 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
private Saml2WebSsoAuthenticationRequestFilter filter;
private RelyingPartyRegistrationRepository repository = mock(RelyingPartyRegistrationRepository.class);
private Saml2AuthenticationRequestFactory factory = mock(Saml2AuthenticationRequestFactory.class);
private Saml2AuthenticationRequestContextResolver resolver =
mock(Saml2AuthenticationRequestContextResolver.class);
private MockHttpServletRequest request;
private MockHttpServletResponse response;
private MockFilterChain filterChain;
@ -188,12 +192,14 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
when(authenticationRequest.getAuthenticationRequestUri()).thenReturn("uri");
when(authenticationRequest.getRelayState()).thenReturn("relay");
when(authenticationRequest.getSamlRequest()).thenReturn("saml");
when(this.repository.findByRegistrationId("registration-id")).thenReturn(relyingParty);
when(this.resolver.resolve(this.request)).thenReturn(authenticationRequestContext()
.relyingPartyRegistration(relyingParty)
.build());
when(this.factory.createPostAuthenticationRequest(any()))
.thenReturn(authenticationRequest);
Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter
(this.repository, this.factory);
(this.resolver, this.factory);
filter.doFilterInternal(this.request, this.response, this.filterChain);
assertThat(this.response.getContentAsString())
.contains("<form action=\"uri\" method=\"post\">")

View File

@ -44,11 +44,13 @@ public class DefaultSaml2AuthenticationRequestContextResolverTests {
private MockHttpServletRequest request;
private RelyingPartyRegistration.Builder relyingPartyBuilder;
private Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver
= new DefaultSaml2AuthenticationRequestContextResolver();
= new DefaultSaml2AuthenticationRequestContextResolver(
new DefaultRelyingPartyRegistrationResolver(id -> relyingPartyBuilder.build()));
@Before
public void setup() {
this.request = new MockHttpServletRequest();
this.request.setPathInfo("/saml2/authenticate/registration-id");
this.relyingPartyBuilder = RelyingPartyRegistration
.withRegistrationId(REGISTRATION_ID)
.localEntityIdTemplate(RELYING_PARTY_ENTITY_ID)
@ -61,52 +63,43 @@ public class DefaultSaml2AuthenticationRequestContextResolverTests {
@Test
public void resolveWhenRequestAndRelyingPartyNotNullThenCreateSaml2AuthenticationRequestContext() {
this.request.addParameter("RelayState", "relay-state");
RelyingPartyRegistration relyingParty = this.relyingPartyBuilder.build();
Saml2AuthenticationRequestContext context =
this.authenticationRequestContextResolver.resolve(this.request, relyingParty);
this.authenticationRequestContextResolver.resolve(this.request);
assertThat(context).isNotNull();
assertThat(context.getAssertionConsumerServiceUrl()).isEqualTo(RELYING_PARTY_SSO_URL);
assertThat(context.getRelayState()).isEqualTo("relay-state");
assertThat(context.getDestination()).isEqualTo(ASSERTING_PARTY_SSO_URL);
assertThat(context.getIssuer()).isEqualTo(RELYING_PARTY_ENTITY_ID);
assertThat(context.getRelyingPartyRegistration()).isSameAs(relyingParty);
assertThat(context.getRelyingPartyRegistration().getRegistrationId())
.isSameAs(this.relyingPartyBuilder.build().getRegistrationId());
}
@Test
public void resolveWhenAssertionConsumerServiceUrlTemplateContainsRegistrationIdThenResolves() {
RelyingPartyRegistration relyingParty = this.relyingPartyBuilder
.assertionConsumerServiceUrlTemplate("/saml2/authenticate/{registrationId}")
.build();
this.relyingPartyBuilder
.assertionConsumerServiceLocation("/saml2/authenticate/{registrationId}");
Saml2AuthenticationRequestContext context =
this.authenticationRequestContextResolver.resolve(this.request, relyingParty);
this.authenticationRequestContextResolver.resolve(this.request);
assertThat(context.getAssertionConsumerServiceUrl()).isEqualTo("/saml2/authenticate/registration-id");
}
@Test
public void resolveWhenAssertionConsumerServiceUrlTemplateContainsBaseUrlThenResolves() {
RelyingPartyRegistration relyingParty = this.relyingPartyBuilder
.assertionConsumerServiceUrlTemplate("{baseUrl}/saml2/authenticate/{registrationId}")
.build();
this.relyingPartyBuilder
.assertionConsumerServiceLocation("{baseUrl}/saml2/authenticate/{registrationId}");
Saml2AuthenticationRequestContext context =
this.authenticationRequestContextResolver.resolve(this.request, relyingParty);
this.authenticationRequestContextResolver.resolve(this.request);
assertThat(context.getAssertionConsumerServiceUrl())
.isEqualTo("http://localhost/saml2/authenticate/registration-id");
}
@Test
public void resolveWhenRequestNullThenException() {
assertThatCode(() ->
this.authenticationRequestContextResolver.resolve(this.request, null))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void resolveWhenRelyingPartyNullThenException() {
assertThatCode(() ->
this.authenticationRequestContextResolver.resolve(null, this.relyingPartyBuilder.build()))
this.authenticationRequestContextResolver.resolve(null))
.isInstanceOf(IllegalArgumentException.class);
}
}