Add query parameter support for authn requests

Closes gh-15017
This commit is contained in:
Josh Cummings 2024-06-21 19:17:42 -06:00
parent 587aa495f7
commit 796e4d6b6c
No known key found for this signature in database
GPG Key ID: A306A51F43B8E5A5
6 changed files with 237 additions and 17 deletions

View File

@ -16,9 +16,13 @@
package org.springframework.security.config.annotation.web.configurers.saml2;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import jakarta.servlet.http.HttpServletRequest;
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.context.ApplicationContext;
import org.springframework.security.authentication.AuthenticationManager;
@ -33,6 +37,7 @@ import org.springframework.security.saml2.provider.service.authentication.Abstra
import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrations;
import org.springframework.security.saml2.provider.service.web.HttpSessionSaml2AuthenticationRequestRepository;
import org.springframework.security.saml2.provider.service.web.OpenSamlAuthenticationTokenConverter;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository;
@ -50,6 +55,7 @@ import org.springframework.security.web.util.matcher.AndRequestMatcher;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.NegatedRequestMatcher;
import org.springframework.security.web.util.matcher.OrRequestMatcher;
import org.springframework.security.web.util.matcher.ParameterRequestMatcher;
import org.springframework.security.web.util.matcher.RequestHeaderRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatchers;
@ -111,7 +117,13 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
private String loginPage;
private String authenticationRequestUri = Saml2AuthenticationRequestResolver.DEFAULT_AUTHENTICATION_REQUEST_URI;
private String authenticationRequestUri = "/saml2/authenticate";
private String[] authenticationRequestParams = { "registrationId={registrationId}" };
private RequestMatcher authenticationRequestMatcher = RequestMatchers.anyOf(
new AntPathRequestMatcher(Saml2AuthenticationRequestResolver.DEFAULT_AUTHENTICATION_REQUEST_URI),
new AntPathQueryRequestMatcher(this.authenticationRequestUri, this.authenticationRequestParams));
private Saml2AuthenticationRequestResolver authenticationRequestResolver;
@ -196,11 +208,31 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
* Request
* @return the {@link Saml2LoginConfigurer} for further configuration
* @since 6.0
* @deprecated Use {@link #authenticationRequestUriQuery} instead
*/
public Saml2LoginConfigurer<B> authenticationRequestUri(String authenticationRequestUri) {
Assert.state(authenticationRequestUri.contains("{registrationId}"),
"authenticationRequestUri must contain {registrationId} path variable");
this.authenticationRequestUri = authenticationRequestUri;
return authenticationRequestUriQuery(authenticationRequestUri);
}
/**
* Customize the URL that the SAML Authentication Request will be sent to. This method
* also supports query parameters like so: <pre>
* authenticationRequestUriQuery("/saml/authenticate?registrationId={registrationId}")
* </pre> {@link RelyingPartyRegistrations}
* @param authenticationRequestUriQuery the URI and query to use for the SAML 2.0
* Authentication Request
* @return the {@link Saml2LoginConfigurer} for further configuration
* @since 6.0
*/
public Saml2LoginConfigurer<B> authenticationRequestUriQuery(String authenticationRequestUriQuery) {
Assert.state(authenticationRequestUriQuery.contains("{registrationId}"),
"authenticationRequestUri must contain {registrationId} path variable or query value");
String[] parts = authenticationRequestUriQuery.split("[?&]");
this.authenticationRequestUri = parts[0];
this.authenticationRequestParams = new String[parts.length - 1];
System.arraycopy(parts, 1, this.authenticationRequestParams, 0, parts.length - 1);
this.authenticationRequestMatcher = new AntPathQueryRequestMatcher(this.authenticationRequestUri,
this.authenticationRequestParams);
return this;
}
@ -255,7 +287,7 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
}
else {
Map<String, String> providerUrlMap = getIdentityProviderUrlMap(this.authenticationRequestUri,
this.relyingPartyRegistrationRepository);
this.authenticationRequestParams, this.relyingPartyRegistrationRepository);
boolean singleProvider = providerUrlMap.size() == 1;
if (singleProvider) {
// Setup auto-redirect to provider login page
@ -336,8 +368,7 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
}
OpenSaml4AuthenticationRequestResolver openSaml4AuthenticationRequestResolver = new OpenSaml4AuthenticationRequestResolver(
relyingPartyRegistrationRepository(http));
openSaml4AuthenticationRequestResolver
.setRequestMatcher(new AntPathRequestMatcher(this.authenticationRequestUri));
openSaml4AuthenticationRequestResolver.setRequestMatcher(this.authenticationRequestMatcher);
return openSaml4AuthenticationRequestResolver;
}
@ -382,20 +413,28 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
return;
}
loginPageGeneratingFilter.setSaml2LoginEnabled(true);
loginPageGeneratingFilter.setSaml2AuthenticationUrlToProviderName(
this.getIdentityProviderUrlMap(this.authenticationRequestUri, this.relyingPartyRegistrationRepository));
loginPageGeneratingFilter
.setSaml2AuthenticationUrlToProviderName(this.getIdentityProviderUrlMap(this.authenticationRequestUri,
this.authenticationRequestParams, this.relyingPartyRegistrationRepository));
loginPageGeneratingFilter.setLoginPageUrl(this.getLoginPage());
loginPageGeneratingFilter.setFailureUrl(this.getFailureUrl());
}
@SuppressWarnings("unchecked")
private Map<String, String> getIdentityProviderUrlMap(String authRequestPrefixUrl,
private Map<String, String> getIdentityProviderUrlMap(String authRequestPrefixUrl, String[] authRequestQueryParams,
RelyingPartyRegistrationRepository idpRepo) {
Map<String, String> idps = new LinkedHashMap<>();
if (idpRepo instanceof Iterable) {
Iterable<RelyingPartyRegistration> repo = (Iterable<RelyingPartyRegistration>) idpRepo;
repo.forEach((p) -> idps.put(authRequestPrefixUrl.replace("{registrationId}", p.getRegistrationId()),
p.getRegistrationId()));
StringBuilder authRequestQuery = new StringBuilder("?");
for (String authRequestQueryParam : authRequestQueryParams) {
authRequestQuery.append(authRequestQueryParam + "&");
}
authRequestQuery.deleteCharAt(authRequestQuery.length() - 1);
String authenticationRequestUriQuery = authRequestPrefixUrl + authRequestQuery;
repo.forEach(
(p) -> idps.put(authenticationRequestUriQuery.replace("{registrationId}", p.getRegistrationId()),
p.getRegistrationId()));
}
return idps;
}
@ -437,4 +476,35 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
}
}
static class AntPathQueryRequestMatcher implements RequestMatcher {
private final RequestMatcher matcher;
AntPathQueryRequestMatcher(String path, String... params) {
List<RequestMatcher> matchers = new ArrayList<>();
matchers.add(new AntPathRequestMatcher(path));
for (String param : params) {
String[] parts = param.split("=");
if (parts.length == 1) {
matchers.add(new ParameterRequestMatcher(parts[0]));
}
else {
matchers.add(new ParameterRequestMatcher(parts[0], parts[1]));
}
}
this.matcher = new AndRequestMatcher(matchers);
}
@Override
public boolean matches(HttpServletRequest request) {
return matcher(request).isMatch();
}
@Override
public MatchResult matcher(HttpServletRequest request) {
return this.matcher.matcher(request);
}
}
}

View File

@ -48,6 +48,7 @@ import org.springframework.security.web.authentication.AuthenticationSuccessHand
class Saml2Dsl {
var relyingPartyRegistrationRepository: RelyingPartyRegistrationRepository? = null
var loginPage: String? = null
var authenticationRequestUriQuery: String? = null
var authenticationSuccessHandler: AuthenticationSuccessHandler? = null
var authenticationFailureHandler: AuthenticationFailureHandler? = null
var failureUrl: String? = null
@ -88,6 +89,9 @@ class Saml2Dsl {
defaultSuccessUrlOption?.also {
saml2Login.defaultSuccessUrl(defaultSuccessUrlOption!!.first, defaultSuccessUrlOption!!.second)
}
authenticationRequestUriQuery?.also {
saml2Login.authenticationRequestUriQuery(authenticationRequestUriQuery)
}
authenticationSuccessHandler?.also { saml2Login.successHandler(authenticationSuccessHandler) }
authenticationFailureHandler?.also { saml2Login.failureHandler(authenticationFailureHandler) }
authenticationManager?.also { saml2Login.authenticationManager(authenticationManager) }

View File

@ -101,6 +101,7 @@ import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;
import static org.assertj.core.api.Assertions.assertThat;
import static org.hamcrest.Matchers.startsWith;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.atLeastOnce;
@ -113,6 +114,7 @@ import static org.springframework.security.config.annotation.SecurityContextChan
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
@ -343,6 +345,19 @@ public class Saml2LoginConfigurerTests {
any(HttpServletRequest.class), any(HttpServletResponse.class));
}
@Test
public void authenticationRequestWhenCustomAuthenticationRequestPathRepositoryThenUses() throws Exception {
this.spring.register(CustomAuthenticationRequestUriQuery.class).autowire();
MockHttpServletRequestBuilder request = get("/custom/auth/sso");
this.mvc.perform(request)
.andExpect(status().isFound())
.andExpect(redirectedUrl("http://localhost/custom/auth/sso?entityId=registration-id"));
request.queryParam("entityId", registration.getRegistrationId());
MvcResult result = this.mvc.perform(request).andExpect(status().isFound()).andReturn();
String redirectedUrl = result.getResponse().getRedirectedUrl();
assertThat(redirectedUrl).startsWith(registration.getAssertingPartyDetails().getSingleSignOnServiceLocation());
}
@Test
public void saml2LoginWhenLoginProcessingUrlWithoutRegistrationIdAndDefaultAuthenticationConverterThenAutowires()
throws Exception {
@ -390,7 +405,7 @@ public class Saml2LoginConfigurerTests {
.andExpect(redirectedUrl("http://localhost/login"));
this.mvc.perform(get("/").accept(MediaType.TEXT_HTML))
.andExpect(status().isFound())
.andExpect(redirectedUrl("http://localhost/saml2/authenticate/registration-id"));
.andExpect(header().string("Location", startsWith("http://localhost/saml2/authenticate")));
}
@Test
@ -669,6 +684,23 @@ public class Saml2LoginConfigurerTests {
}
@Configuration
@EnableWebSecurity
@Import(Saml2LoginConfigBeans.class)
static class CustomAuthenticationRequestUriQuery {
@Bean
SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
// @formatter:off
http
.authorizeHttpRequests((authz) -> authz.anyRequest().authenticated())
.saml2Login((saml2) -> saml2.authenticationRequestUriQuery("/custom/auth/sso?entityId={registrationId}"));
// @formatter:on
return http.build();
}
}
@Configuration
@EnableWebSecurity
@Import(Saml2LoginConfigBeans.class)

View File

@ -43,11 +43,13 @@ import org.springframework.security.saml2.provider.service.registration.TestRely
import org.springframework.security.saml2.provider.service.web.authentication.Saml2WebSsoAuthenticationFilter
import org.springframework.security.web.SecurityFilterChain
import org.springframework.test.web.servlet.MockMvc
import org.springframework.test.web.servlet.MvcResult
import org.springframework.test.web.servlet.get
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders
import org.springframework.test.web.servlet.result.MockMvcResultMatchers
import java.security.cert.Certificate
import java.security.cert.CertificateFactory
import java.util.Base64
import java.util.*
/**
* Tests for [Saml2Dsl]
@ -136,6 +138,23 @@ class Saml2DslTests {
verify(exactly = 1) { Saml2LoginCustomAuthenticationManagerConfig.AUTHENTICATION_MANAGER.authenticate(any()) }
}
@Test
@Throws(Exception::class)
fun authenticationRequestWhenCustomAuthenticationRequestPathRepositoryThenUses() {
this.spring.register(CustomAuthenticationRequestUriQuery::class.java).autowire()
val registration = TestRelyingPartyRegistrations.relyingPartyRegistration().build();
val request = MockMvcRequestBuilders.get("/custom/auth/sso")
this.mockMvc.perform(request)
.andExpect(MockMvcResultMatchers.status().isFound())
.andExpect(MockMvcResultMatchers.redirectedUrl("http://localhost/custom/auth/sso?entityId=simplesamlphp"))
request.queryParam("entityId", registration.registrationId)
val result: MvcResult =
this.mockMvc.perform(request).andExpect(MockMvcResultMatchers.status().isFound()).andReturn()
val redirectedUrl = result.response.redirectedUrl
Assertions.assertThat(redirectedUrl)
.startsWith(registration.assertingPartyDetails.singleSignOnServiceLocation)
}
@Configuration
@EnableWebSecurity
open class Saml2LoginCustomAuthenticationManagerConfig {
@ -162,4 +181,26 @@ class Saml2DslTests {
return repository
}
}
@Configuration
@EnableWebSecurity
open class CustomAuthenticationRequestUriQuery {
@Bean
open fun securityFilterChain(http: HttpSecurity): SecurityFilterChain {
http {
authorizeHttpRequests {
authorize(anyRequest, authenticated)
}
saml2Login {
authenticationRequestUriQuery = "/custom/auth/sso?entityId={registrationId}"
}
}
return http.build()
}
@Bean
open fun relyingPartyRegistrationRepository(): RelyingPartyRegistrationRepository? {
return InMemoryRelyingPartyRegistrationRepository(TestRelyingPartyRegistrations.relyingPartyRegistration().build())
}
}
}

View File

@ -4,7 +4,7 @@
As stated earlier, Spring Security's SAML 2.0 support produces a `<saml2:AuthnRequest>` to commence authentication with the asserting party.
Spring Security achieves this in part by registering the `Saml2WebSsoAuthenticationRequestFilter` in the filter chain.
This filter by default responds to endpoint `+/saml2/authenticate/{registrationId}+`.
This filter by default responds to the endpoints `+/saml2/authenticate/{registrationId}+` and `+/saml2/authenticate?registrationId={registrationId}+`.
For example, if you were deployed to `https://rp.example.com` and you gave your registration an ID of `okta`, you could navigate to:
@ -12,6 +12,42 @@ For example, if you were deployed to `https://rp.example.com` and you gave your
and the result would be a redirect that included a `SAMLRequest` parameter containing the signed, deflated, and encoded `<saml2:AuthnRequest>`.
== Configuring the `<saml2:AuthnRequest>` Endpoint
To configure the endpoint differently from the default, you can set the value in `saml2Login`:
[tabs]
======
Java::
+
[source,java,role="primary"]
----
@Bean
SecurityFilterChain filterChain(HttpSecurity http) {
http
.saml2Login((saml2) -> saml2
.authenticationRequestUriQuery("/custom/auth/sso?peerEntityID={registrationId}")
);
return new CustomSaml2AuthenticationRequestRepository();
}
----
Kotlin::
+
[source,kotlin,role="secondary"]
----
@Bean
fun filterChain(http: HttpSecurity): SecurityFilterChain {
http {
saml2Login {
authenticationRequestUriQuery = "/custom/auth/sso?peerEntityID={registrationId}"
}
}
return CustomSaml2AuthenticationRequestRepository()
}
----
======
[[servlet-saml2login-store-authn-request]]
== Changing How the `<saml2:AuthnRequest>` Gets Stored

View File

@ -17,6 +17,8 @@
package org.springframework.security.saml2.provider.service.web.authentication;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.function.BiConsumer;
@ -50,8 +52,11 @@ import org.springframework.security.saml2.provider.service.registration.Saml2Mes
import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers;
import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers.UriResolver;
import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver;
import org.springframework.security.web.util.matcher.AndRequestMatcher;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.ParameterRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatchers;
import org.springframework.util.Assert;
/**
@ -75,8 +80,9 @@ class OpenSamlAuthenticationRequestResolver {
private final NameIDPolicyBuilder nameIdPolicyBuilder;
private RequestMatcher requestMatcher = new AntPathRequestMatcher(
Saml2AuthenticationRequestResolver.DEFAULT_AUTHENTICATION_REQUEST_URI);
private RequestMatcher requestMatcher = RequestMatchers.anyOf(
new AntPathRequestMatcher(Saml2AuthenticationRequestResolver.DEFAULT_AUTHENTICATION_REQUEST_URI),
new AntPathQueryRequestMatcher("/saml2/authenticate", "registrationId={registrationId}"));
private Converter<HttpServletRequest, String> relayStateResolver = (request) -> UUID.randomUUID().toString();
@ -199,4 +205,35 @@ class OpenSamlAuthenticationRequestResolver {
}
}
private static final class AntPathQueryRequestMatcher implements RequestMatcher {
private final RequestMatcher matcher;
AntPathQueryRequestMatcher(String path, String... params) {
List<RequestMatcher> matchers = new ArrayList<>();
matchers.add(new AntPathRequestMatcher(path));
for (String param : params) {
String[] parts = param.split("=");
if (parts.length == 1) {
matchers.add(new ParameterRequestMatcher(parts[0]));
}
else {
matchers.add(new ParameterRequestMatcher(parts[0], parts[1]));
}
}
this.matcher = new AndRequestMatcher(matchers);
}
@Override
public boolean matches(HttpServletRequest request) {
return matcher(request).isMatch();
}
@Override
public MatchResult matcher(HttpServletRequest request) {
return this.matcher.matcher(request);
}
}
}