Polish DefaultSaml2AuthenticationRequestContextResolver
Issue gh-8360 Issue gh-8887
This commit is contained in:
parent
015281ff53
commit
a10c2c6cf8
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2019 the original author or authors.
|
||||
* Copyright 2002-2020 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -35,6 +35,9 @@ import org.springframework.security.saml2.provider.service.registration.RelyingP
|
|||
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
|
||||
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
|
||||
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter;
|
||||
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
|
||||
import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver;
|
||||
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
|
||||
import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint;
|
||||
import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter;
|
||||
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
|
||||
|
@ -317,15 +320,16 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>> extend
|
|||
|
||||
private final class AuthenticationRequestEndpointConfig {
|
||||
private String filterProcessingUrl = "/saml2/authenticate/{registrationId}";
|
||||
|
||||
private AuthenticationRequestEndpointConfig() {
|
||||
}
|
||||
|
||||
private Filter build(B http) {
|
||||
Saml2AuthenticationRequestFactory authenticationRequestResolver = getResolver(http);
|
||||
Saml2AuthenticationRequestContextResolver contextResolver = getContextResolver(http);
|
||||
|
||||
return postProcess(new Saml2WebSsoAuthenticationRequestFilter(
|
||||
Saml2LoginConfigurer.this.relyingPartyRegistrationRepository,
|
||||
authenticationRequestResolver));
|
||||
contextResolver, authenticationRequestResolver));
|
||||
}
|
||||
|
||||
private Saml2AuthenticationRequestFactory getResolver(B http) {
|
||||
|
@ -335,6 +339,16 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>> extend
|
|||
}
|
||||
return resolver;
|
||||
}
|
||||
|
||||
private Saml2AuthenticationRequestContextResolver getContextResolver(B http) {
|
||||
Saml2AuthenticationRequestContextResolver resolver = getBeanOrNull(http, Saml2AuthenticationRequestContextResolver.class);
|
||||
if (resolver == null) {
|
||||
return new DefaultSaml2AuthenticationRequestContextResolver(
|
||||
new DefaultRelyingPartyRegistrationResolver(
|
||||
Saml2LoginConfigurer.this.relyingPartyRegistrationRepository));
|
||||
}
|
||||
return resolver;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -65,10 +65,8 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2A
|
|||
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
|
||||
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory;
|
||||
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
|
||||
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
|
||||
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
|
||||
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
|
||||
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter;
|
||||
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
|
||||
import org.springframework.security.web.FilterChainProxy;
|
||||
import org.springframework.security.web.context.HttpRequestResponseHolder;
|
||||
|
@ -87,6 +85,7 @@ import static org.mockito.ArgumentMatchers.anyString;
|
|||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.springframework.security.config.Customizer.withDefaults;
|
||||
import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext;
|
||||
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration;
|
||||
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
|
||||
|
@ -161,11 +160,11 @@ public class Saml2LoginConfigurerTests {
|
|||
Saml2AuthenticationRequestContext context = authenticationRequestContext().build();
|
||||
Saml2AuthenticationRequestContextResolver resolver =
|
||||
CustomAuthenticationRequestContextResolver.resolver;
|
||||
when(resolver.resolve(any(HttpServletRequest.class), any(RelyingPartyRegistration.class)))
|
||||
when(resolver.resolve(any(HttpServletRequest.class)))
|
||||
.thenReturn(context);
|
||||
this.mvc.perform(get("/saml2/authenticate/registration-id"))
|
||||
.andExpect(status().isFound());
|
||||
verify(resolver).resolve(any(HttpServletRequest.class), any(RelyingPartyRegistration.class));
|
||||
verify(resolver).resolve(any(HttpServletRequest.class));
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -276,22 +275,11 @@ public class Saml2LoginConfigurerTests {
|
|||
|
||||
@Override
|
||||
protected void configure(HttpSecurity http) throws Exception {
|
||||
ObjectPostProcessor<Saml2WebSsoAuthenticationRequestFilter> processor
|
||||
= new ObjectPostProcessor<Saml2WebSsoAuthenticationRequestFilter>() {
|
||||
@Override
|
||||
public <O extends Saml2WebSsoAuthenticationRequestFilter> O postProcess(O filter) {
|
||||
filter.setAuthenticationRequestContextResolver(resolver);
|
||||
return filter;
|
||||
}
|
||||
};
|
||||
|
||||
http
|
||||
.authorizeRequests(authz -> authz
|
||||
.anyRequest().authenticated()
|
||||
)
|
||||
.saml2Login(saml2 -> saml2
|
||||
.addObjectPostProcessor(processor)
|
||||
);
|
||||
.saml2Login(withDefaults());
|
||||
}
|
||||
|
||||
@Bean
|
||||
|
|
|
@ -30,6 +30,7 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2R
|
|||
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
|
||||
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
|
||||
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
|
||||
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
|
||||
import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver;
|
||||
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
|
||||
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
|
||||
|
@ -69,9 +70,8 @@ import static java.nio.charset.StandardCharsets.ISO_8859_1;
|
|||
*/
|
||||
public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter {
|
||||
|
||||
private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
|
||||
private final Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver;
|
||||
private Saml2AuthenticationRequestFactory authenticationRequestFactory;
|
||||
private Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver = new DefaultSaml2AuthenticationRequestContextResolver();
|
||||
|
||||
private RequestMatcher redirectMatcher = new AntPathRequestMatcher("/saml2/authenticate/{registrationId}");
|
||||
|
||||
|
@ -83,21 +83,24 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
|
|||
*/
|
||||
@Deprecated
|
||||
public Saml2WebSsoAuthenticationRequestFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) {
|
||||
this(relyingPartyRegistrationRepository,
|
||||
this(new DefaultSaml2AuthenticationRequestContextResolver(
|
||||
new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository)),
|
||||
new org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationRequestFactory());
|
||||
}
|
||||
|
||||
/**
|
||||
* Construct a {@link Saml2WebSsoAuthenticationRequestFilter} with the provided parameters
|
||||
*
|
||||
* @param relyingPartyRegistrationRepository a repository for relying party configurations
|
||||
* @param authenticationRequestContextResolver a strategy for formulating a {@link Saml2AuthenticationRequestContext}
|
||||
* @since 5.4
|
||||
*/
|
||||
public Saml2WebSsoAuthenticationRequestFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository,
|
||||
public Saml2WebSsoAuthenticationRequestFilter(
|
||||
Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver,
|
||||
Saml2AuthenticationRequestFactory authenticationRequestFactory) {
|
||||
Assert.notNull(relyingPartyRegistrationRepository, "relyingPartyRegistrationRepository cannot be null");
|
||||
|
||||
Assert.notNull(authenticationRequestContextResolver, "authenticationRequestContextResolver cannot be null");
|
||||
Assert.notNull(authenticationRequestFactory, "authenticationRequestFactory cannot be null");
|
||||
this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository;
|
||||
this.authenticationRequestContextResolver = authenticationRequestContextResolver;
|
||||
this.authenticationRequestFactory = authenticationRequestFactory;
|
||||
}
|
||||
|
||||
|
@ -123,17 +126,6 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
|
|||
this.redirectMatcher = redirectMatcher;
|
||||
}
|
||||
|
||||
/**
|
||||
* Use the given {@link Saml2AuthenticationRequestContextResolver} that creates a {@link Saml2AuthenticationRequestContext}
|
||||
*
|
||||
* @param authenticationRequestContextResolver the {@link Saml2AuthenticationRequestContextResolver} to use
|
||||
* @since 5.4
|
||||
*/
|
||||
public void setAuthenticationRequestContextResolver(Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver) {
|
||||
Assert.notNull(authenticationRequestContextResolver, "authenticationRequestContextResolver cannot be null");
|
||||
this.authenticationRequestContextResolver = authenticationRequestContextResolver;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
|
@ -147,14 +139,12 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
|
|||
return;
|
||||
}
|
||||
|
||||
String registrationId = matcher.getVariables().get("registrationId");
|
||||
RelyingPartyRegistration relyingParty =
|
||||
this.relyingPartyRegistrationRepository.findByRegistrationId(registrationId);
|
||||
if (relyingParty == null) {
|
||||
Saml2AuthenticationRequestContext context = this.authenticationRequestContextResolver.resolve(request);
|
||||
if (context == null) {
|
||||
response.sendError(HttpServletResponse.SC_UNAUTHORIZED);
|
||||
return;
|
||||
}
|
||||
Saml2AuthenticationRequestContext context = authenticationRequestContextResolver.resolve(request, relyingParty);
|
||||
RelyingPartyRegistration relyingParty = context.getRelyingPartyRegistration();
|
||||
if (relyingParty.getAssertingPartyDetails().getSingleSignOnServiceBinding() == Saml2MessageBinding.REDIRECT) {
|
||||
sendRedirect(response, context);
|
||||
} else {
|
||||
|
|
|
@ -16,45 +16,45 @@
|
|||
|
||||
package org.springframework.security.saml2.provider.service.web;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.function.Function;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
|
||||
import org.springframework.core.convert.converter.Converter;
|
||||
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
|
||||
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.StringUtils;
|
||||
import org.springframework.web.util.UriComponents;
|
||||
import org.springframework.web.util.UriComponentsBuilder;
|
||||
|
||||
import static org.springframework.security.web.util.UrlUtils.buildFullRequestUrl;
|
||||
import static org.springframework.web.util.UriComponentsBuilder.fromHttpUrl;
|
||||
|
||||
/**
|
||||
* The default implementation for {@link Saml2AuthenticationRequestContextResolver}
|
||||
* which uses the current request and given relying party to formulate a {@link Saml2AuthenticationRequestContext}
|
||||
*
|
||||
* @author Shazin Sadakath
|
||||
* @author Josh Cummings
|
||||
* @since 5.4
|
||||
*/
|
||||
public final class DefaultSaml2AuthenticationRequestContextResolver implements Saml2AuthenticationRequestContextResolver {
|
||||
|
||||
private final Log logger = LogFactory.getLog(getClass());
|
||||
|
||||
private static final char PATH_DELIMITER = '/';
|
||||
private final Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver;
|
||||
|
||||
public DefaultSaml2AuthenticationRequestContextResolver
|
||||
(Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver) {
|
||||
this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public Saml2AuthenticationRequestContext resolve(HttpServletRequest request,
|
||||
RelyingPartyRegistration relyingParty) {
|
||||
public Saml2AuthenticationRequestContext resolve(HttpServletRequest request) {
|
||||
Assert.notNull(request, "request cannot be null");
|
||||
Assert.notNull(relyingParty, "relyingParty cannot be null");
|
||||
RelyingPartyRegistration relyingParty = this.relyingPartyRegistrationResolver.convert(request);
|
||||
if (relyingParty == null) {
|
||||
return null;
|
||||
}
|
||||
if (this.logger.isDebugEnabled()) {
|
||||
this.logger.debug("Creating SAML 2.0 Authentication Request for Asserting Party [" +
|
||||
relyingParty.getRegistrationId() + "]");
|
||||
|
@ -65,59 +65,11 @@ public final class DefaultSaml2AuthenticationRequestContextResolver implements S
|
|||
private Saml2AuthenticationRequestContext createRedirectAuthenticationRequestContext(
|
||||
HttpServletRequest request, RelyingPartyRegistration relyingParty) {
|
||||
|
||||
String applicationUri = getApplicationUri(request);
|
||||
Function<String, String> resolver = templateResolver(applicationUri, relyingParty);
|
||||
String localSpEntityId = resolver.apply(relyingParty.getEntityId());
|
||||
String assertionConsumerServiceUrl = resolver.apply(relyingParty.getAssertionConsumerServiceLocation());
|
||||
return Saml2AuthenticationRequestContext.builder()
|
||||
.issuer(localSpEntityId)
|
||||
.issuer(relyingParty.getEntityId())
|
||||
.relyingPartyRegistration(relyingParty)
|
||||
.assertionConsumerServiceUrl(assertionConsumerServiceUrl)
|
||||
.assertionConsumerServiceUrl(relyingParty.getAssertionConsumerServiceLocation())
|
||||
.relayState(request.getParameter("RelayState"))
|
||||
.build();
|
||||
}
|
||||
|
||||
private Function<String, String> templateResolver(String applicationUri, RelyingPartyRegistration relyingParty) {
|
||||
return template -> resolveUrlTemplate(template, applicationUri, relyingParty);
|
||||
}
|
||||
|
||||
private static String resolveUrlTemplate(String template, String baseUrl, RelyingPartyRegistration relyingParty) {
|
||||
String entityId = relyingParty.getAssertingPartyDetails().getEntityId();
|
||||
String registrationId = relyingParty.getRegistrationId();
|
||||
Map<String, String> uriVariables = new HashMap<>();
|
||||
UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(baseUrl)
|
||||
.replaceQuery(null)
|
||||
.fragment(null)
|
||||
.build();
|
||||
String scheme = uriComponents.getScheme();
|
||||
uriVariables.put("baseScheme", scheme == null ? "" : scheme);
|
||||
String host = uriComponents.getHost();
|
||||
uriVariables.put("baseHost", host == null ? "" : host);
|
||||
// following logic is based on HierarchicalUriComponents#toUriString()
|
||||
int port = uriComponents.getPort();
|
||||
uriVariables.put("basePort", port == -1 ? "" : ":" + port);
|
||||
String path = uriComponents.getPath();
|
||||
if (StringUtils.hasLength(path)) {
|
||||
if (path.charAt(0) != PATH_DELIMITER) {
|
||||
path = PATH_DELIMITER + path;
|
||||
}
|
||||
}
|
||||
uriVariables.put("basePath", path == null ? "" : path);
|
||||
uriVariables.put("baseUrl", uriComponents.toUriString());
|
||||
uriVariables.put("entityId", StringUtils.hasText(entityId) ? entityId : "");
|
||||
uriVariables.put("registrationId", StringUtils.hasText(registrationId) ? registrationId : "");
|
||||
|
||||
return UriComponentsBuilder.fromUriString(template)
|
||||
.buildAndExpand(uriVariables)
|
||||
.toUriString();
|
||||
}
|
||||
|
||||
private static String getApplicationUri(HttpServletRequest request) {
|
||||
UriComponents uriComponents = fromHttpUrl(buildFullRequestUrl(request))
|
||||
.replacePath(request.getContextPath())
|
||||
.replaceQuery(null)
|
||||
.fragment(null)
|
||||
.build();
|
||||
return uriComponents.toUriString();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,16 +16,16 @@
|
|||
|
||||
package org.springframework.security.saml2.provider.service.web;
|
||||
|
||||
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
|
||||
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
|
||||
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
|
||||
|
||||
/**
|
||||
* This {@code Saml2AuthenticationRequestContextResolver} formulates a
|
||||
* <a href="https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf">SAML 2.0 AuthnRequest</a> (line 1968)
|
||||
*
|
||||
* @author Shazin Sadakath
|
||||
* @author Josh Cummings
|
||||
* @since 5.4
|
||||
*/
|
||||
public interface Saml2AuthenticationRequestContextResolver {
|
||||
|
@ -35,9 +35,7 @@ public interface Saml2AuthenticationRequestContextResolver {
|
|||
*
|
||||
*
|
||||
* @param request the current request
|
||||
* @param relyingParty the relying party responsible for saml2 sso authentication
|
||||
* @return the created {@link Saml2AuthenticationRequestContext} for request/relying party combination
|
||||
* @return the created {@link Saml2AuthenticationRequestContext} for the request
|
||||
*/
|
||||
Saml2AuthenticationRequestContext resolve(HttpServletRequest request,
|
||||
RelyingPartyRegistration relyingParty);
|
||||
Saml2AuthenticationRequestContext resolve(HttpServletRequest request);
|
||||
}
|
||||
|
|
|
@ -30,6 +30,7 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2A
|
|||
import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
|
||||
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
|
||||
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
|
||||
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
|
||||
import org.springframework.web.util.HtmlUtils;
|
||||
import org.springframework.web.util.UriUtils;
|
||||
|
||||
|
@ -41,6 +42,7 @@ import static org.mockito.Mockito.verify;
|
|||
import static org.mockito.Mockito.verifyNoInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyPrivateCredential;
|
||||
import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext;
|
||||
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST;
|
||||
|
||||
public class Saml2WebSsoAuthenticationRequestFilterTests {
|
||||
|
@ -49,6 +51,8 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
|
|||
private Saml2WebSsoAuthenticationRequestFilter filter;
|
||||
private RelyingPartyRegistrationRepository repository = mock(RelyingPartyRegistrationRepository.class);
|
||||
private Saml2AuthenticationRequestFactory factory = mock(Saml2AuthenticationRequestFactory.class);
|
||||
private Saml2AuthenticationRequestContextResolver resolver =
|
||||
mock(Saml2AuthenticationRequestContextResolver.class);
|
||||
private MockHttpServletRequest request;
|
||||
private MockHttpServletResponse response;
|
||||
private MockFilterChain filterChain;
|
||||
|
@ -188,12 +192,14 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
|
|||
when(authenticationRequest.getAuthenticationRequestUri()).thenReturn("uri");
|
||||
when(authenticationRequest.getRelayState()).thenReturn("relay");
|
||||
when(authenticationRequest.getSamlRequest()).thenReturn("saml");
|
||||
when(this.repository.findByRegistrationId("registration-id")).thenReturn(relyingParty);
|
||||
when(this.resolver.resolve(this.request)).thenReturn(authenticationRequestContext()
|
||||
.relyingPartyRegistration(relyingParty)
|
||||
.build());
|
||||
when(this.factory.createPostAuthenticationRequest(any()))
|
||||
.thenReturn(authenticationRequest);
|
||||
|
||||
Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter
|
||||
(this.repository, this.factory);
|
||||
(this.resolver, this.factory);
|
||||
filter.doFilterInternal(this.request, this.response, this.filterChain);
|
||||
assertThat(this.response.getContentAsString())
|
||||
.contains("<form action=\"uri\" method=\"post\">")
|
||||
|
|
|
@ -44,11 +44,13 @@ public class DefaultSaml2AuthenticationRequestContextResolverTests {
|
|||
private MockHttpServletRequest request;
|
||||
private RelyingPartyRegistration.Builder relyingPartyBuilder;
|
||||
private Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver
|
||||
= new DefaultSaml2AuthenticationRequestContextResolver();
|
||||
= new DefaultSaml2AuthenticationRequestContextResolver(
|
||||
new DefaultRelyingPartyRegistrationResolver(id -> relyingPartyBuilder.build()));
|
||||
|
||||
@Before
|
||||
public void setup() {
|
||||
this.request = new MockHttpServletRequest();
|
||||
this.request.setPathInfo("/saml2/authenticate/registration-id");
|
||||
this.relyingPartyBuilder = RelyingPartyRegistration
|
||||
.withRegistrationId(REGISTRATION_ID)
|
||||
.localEntityIdTemplate(RELYING_PARTY_ENTITY_ID)
|
||||
|
@ -61,52 +63,43 @@ public class DefaultSaml2AuthenticationRequestContextResolverTests {
|
|||
@Test
|
||||
public void resolveWhenRequestAndRelyingPartyNotNullThenCreateSaml2AuthenticationRequestContext() {
|
||||
this.request.addParameter("RelayState", "relay-state");
|
||||
RelyingPartyRegistration relyingParty = this.relyingPartyBuilder.build();
|
||||
Saml2AuthenticationRequestContext context =
|
||||
this.authenticationRequestContextResolver.resolve(this.request, relyingParty);
|
||||
this.authenticationRequestContextResolver.resolve(this.request);
|
||||
|
||||
assertThat(context).isNotNull();
|
||||
assertThat(context.getAssertionConsumerServiceUrl()).isEqualTo(RELYING_PARTY_SSO_URL);
|
||||
assertThat(context.getRelayState()).isEqualTo("relay-state");
|
||||
assertThat(context.getDestination()).isEqualTo(ASSERTING_PARTY_SSO_URL);
|
||||
assertThat(context.getIssuer()).isEqualTo(RELYING_PARTY_ENTITY_ID);
|
||||
assertThat(context.getRelyingPartyRegistration()).isSameAs(relyingParty);
|
||||
assertThat(context.getRelyingPartyRegistration().getRegistrationId())
|
||||
.isSameAs(this.relyingPartyBuilder.build().getRegistrationId());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void resolveWhenAssertionConsumerServiceUrlTemplateContainsRegistrationIdThenResolves() {
|
||||
RelyingPartyRegistration relyingParty = this.relyingPartyBuilder
|
||||
.assertionConsumerServiceUrlTemplate("/saml2/authenticate/{registrationId}")
|
||||
.build();
|
||||
this.relyingPartyBuilder
|
||||
.assertionConsumerServiceLocation("/saml2/authenticate/{registrationId}");
|
||||
Saml2AuthenticationRequestContext context =
|
||||
this.authenticationRequestContextResolver.resolve(this.request, relyingParty);
|
||||
this.authenticationRequestContextResolver.resolve(this.request);
|
||||
|
||||
assertThat(context.getAssertionConsumerServiceUrl()).isEqualTo("/saml2/authenticate/registration-id");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void resolveWhenAssertionConsumerServiceUrlTemplateContainsBaseUrlThenResolves() {
|
||||
RelyingPartyRegistration relyingParty = this.relyingPartyBuilder
|
||||
.assertionConsumerServiceUrlTemplate("{baseUrl}/saml2/authenticate/{registrationId}")
|
||||
.build();
|
||||
this.relyingPartyBuilder
|
||||
.assertionConsumerServiceLocation("{baseUrl}/saml2/authenticate/{registrationId}");
|
||||
Saml2AuthenticationRequestContext context =
|
||||
this.authenticationRequestContextResolver.resolve(this.request, relyingParty);
|
||||
this.authenticationRequestContextResolver.resolve(this.request);
|
||||
|
||||
assertThat(context.getAssertionConsumerServiceUrl())
|
||||
.isEqualTo("http://localhost/saml2/authenticate/registration-id");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void resolveWhenRequestNullThenException() {
|
||||
assertThatCode(() ->
|
||||
this.authenticationRequestContextResolver.resolve(this.request, null))
|
||||
.isInstanceOf(IllegalArgumentException.class);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void resolveWhenRelyingPartyNullThenException() {
|
||||
assertThatCode(() ->
|
||||
this.authenticationRequestContextResolver.resolve(null, this.relyingPartyBuilder.build()))
|
||||
this.authenticationRequestContextResolver.resolve(null))
|
||||
.isInstanceOf(IllegalArgumentException.class);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue