From 16e17d242e1f13b2e3c337d7c516114bf0cbbbd3 Mon Sep 17 00:00:00 2001 From: Marcus Da Coregio Date: Mon, 28 Jun 2021 13:58:42 -0300 Subject: [PATCH] Add Saml2AuthenticationRequestRepository Closes gh-9185 --- .../saml2/Saml2LoginConfigurer.java | 29 +++- .../saml2/Saml2LoginConfigurerTests.java | 49 +++++- .../_includes/servlet/saml2/saml2-login.adoc | 30 ++++ ...ing-security-saml2-service-provider.gradle | 1 + .../Saml2AuthenticationToken.java | 45 +++++- ...nSaml2AuthenticationRequestRepository.java | 73 +++++++++ .../Saml2AuthenticationRequestRepository.java | 60 ++++++++ .../Saml2WebSsoAuthenticationFilter.java | 33 +++- ...aml2WebSsoAuthenticationRequestFilter.java | 31 +++- .../Saml2AuthenticationTokenConverter.java | 27 +++- .../TestSaml2AuthenticationTokens.java | 38 +++++ ...2AuthenticationRequestRepositoryTests.java | 143 ++++++++++++++++++ .../Saml2WebSsoAuthenticationFilterTests.java | 51 ++++++- ...ebSsoAuthenticationRequestFilterTests.java | 42 ++++- ...aml2AuthenticationTokenConverterTests.java | 24 +++ 15 files changed, 657 insertions(+), 19 deletions(-) create mode 100644 saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/HttpSessionSaml2AuthenticationRequestRepository.java create mode 100644 saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/Saml2AuthenticationRequestRepository.java create mode 100644 saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestSaml2AuthenticationTokens.java create mode 100644 saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/HttpSessionSaml2AuthenticationRequestRepositoryTests.java diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java index 05aec9cd21..1821690d3a 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -33,6 +33,7 @@ import org.springframework.security.config.annotation.web.configurers.AbstractAu import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer; import org.springframework.security.config.annotation.web.configurers.CsrfConfigurer; import org.springframework.security.core.Authentication; +import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest; import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider; import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationRequestFactory; import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider; @@ -40,6 +41,8 @@ import org.springframework.security.saml2.provider.service.authentication.OpenSa import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.saml2.provider.service.servlet.HttpSessionSaml2AuthenticationRequestRepository; +import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository; import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter; import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter; import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver; @@ -206,6 +209,7 @@ public final class Saml2LoginConfigurer> } this.saml2WebSsoAuthenticationFilter = new Saml2WebSsoAuthenticationFilter(getAuthenticationConverter(http), this.loginProcessingUrl); + setAuthenticationRequestRepository(http, this.saml2WebSsoAuthenticationFilter); setAuthenticationFilter(this.saml2WebSsoAuthenticationFilter); super.loginProcessingUrl(this.loginProcessingUrl); if (StringUtils.hasText(this.loginPage)) { @@ -252,6 +256,11 @@ public final class Saml2LoginConfigurer> } } + private void setAuthenticationRequestRepository(B http, + Saml2WebSsoAuthenticationFilter saml2WebSsoAuthenticationFilter) { + saml2WebSsoAuthenticationFilter.setAuthenticationRequestRepository(getAuthenticationRequestRepository(http)); + } + private AuthenticationConverter getAuthenticationConverter(B http) { if (this.authenticationConverter == null) { return new Saml2AuthenticationTokenConverter( @@ -311,6 +320,16 @@ public final class Saml2LoginConfigurer> return idps; } + private Saml2AuthenticationRequestRepository getAuthenticationRequestRepository( + B http) { + Saml2AuthenticationRequestRepository repository = getBeanOrNull(http, + Saml2AuthenticationRequestRepository.class); + if (repository == null) { + return new HttpSessionSaml2AuthenticationRequestRepository(); + } + return repository; + } + private C getSharedOrBean(B http, Class clazz) { C shared = http.getSharedObject(clazz); if (shared != null) { @@ -348,8 +367,12 @@ public final class Saml2LoginConfigurer> private Filter build(B http) { Saml2AuthenticationRequestFactory authenticationRequestResolver = getResolver(http); Saml2AuthenticationRequestContextResolver contextResolver = getContextResolver(http); - return postProcess( - new Saml2WebSsoAuthenticationRequestFilter(contextResolver, authenticationRequestResolver)); + Saml2AuthenticationRequestRepository repository = getAuthenticationRequestRepository( + http); + Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter(contextResolver, + authenticationRequestResolver); + filter.setAuthenticationRequestRepository(repository); + return postProcess(filter); } private Saml2AuthenticationRequestFactory getResolver(B http) { diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java index 2c184c91b7..dc5b8b45f7 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java @@ -63,6 +63,7 @@ import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMap import org.springframework.security.saml2.core.Saml2ErrorCodes; import org.springframework.security.saml2.core.Saml2Utils; import org.springframework.security.saml2.core.TestSaml2X509Credentials; +import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest; import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider; import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationRequestFactory; import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider; @@ -76,9 +77,11 @@ import org.springframework.security.saml2.provider.service.authentication.TestSa import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; +import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository; import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver; import org.springframework.security.web.FilterChainProxy; +import org.springframework.security.web.SecurityFilterChain; import org.springframework.security.web.authentication.AuthenticationConverter; import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.context.HttpRequestResponseHolder; @@ -237,6 +240,29 @@ public class Saml2LoginConfigurerTests { assertThat(exception.getCause()).isInstanceOf(IOException.class); } + @Test + public void authenticationRequestWhenCustomAuthnRequestRepositoryThenUses() throws Exception { + this.spring.register(CustomAuthenticationRequestRepository.class).autowire(); + MockHttpServletRequestBuilder request = get("/saml2/authenticate/registration-id"); + this.mvc.perform(request).andExpect(status().isFound()); + Saml2AuthenticationRequestRepository repository = this.spring.getContext() + .getBean(Saml2AuthenticationRequestRepository.class); + verify(repository).saveAuthenticationRequest(any(AbstractSaml2AuthenticationRequest.class), + any(HttpServletRequest.class), any(HttpServletResponse.class)); + } + + @Test + public void authenticateWhenCustomAuthnRequestRepositoryThenUses() throws Exception { + this.spring.register(CustomAuthenticationRequestRepository.class).autowire(); + MockHttpServletRequestBuilder request = post("/login/saml2/sso/registration-id").param("SAMLResponse", + SIGNED_RESPONSE); + Saml2AuthenticationRequestRepository repository = this.spring.getContext() + .getBean(Saml2AuthenticationRequestRepository.class); + this.mvc.perform(request); + verify(repository).loadAuthenticationRequest(any(HttpServletRequest.class)); + verify(repository).removeAuthenticationRequest(any(HttpServletRequest.class), any(HttpServletResponse.class)); + } + private void validateSaml2WebSsoAuthenticationFilterConfiguration() { // get the OpenSamlAuthenticationProvider Saml2WebSsoAuthenticationFilter filter = getSaml2SsoFilter(this.springSecurityFilterChain); @@ -371,7 +397,7 @@ public class Saml2LoginConfigurerTests { @Bean Saml2AuthenticationRequestContextResolver resolver() { - return resolver; + return this.resolver; } } @@ -420,6 +446,27 @@ public class Saml2LoginConfigurerTests { } + @EnableWebSecurity + @Import(Saml2LoginConfigBeans.class) + static class CustomAuthenticationRequestRepository { + + private static final Saml2AuthenticationRequestRepository repository = mock( + Saml2AuthenticationRequestRepository.class); + + @Bean + SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception { + http.authorizeRequests((authz) -> authz.anyRequest().authenticated()); + http.saml2Login(withDefaults()); + return http.build(); + } + + @Bean + Saml2AuthenticationRequestRepository authenticationRequestRepository() { + return this.repository; + } + + } + static class Saml2LoginConfigBeans { @Bean diff --git a/docs/manual/src/docs/asciidoc/_includes/servlet/saml2/saml2-login.adoc b/docs/manual/src/docs/asciidoc/_includes/servlet/saml2/saml2-login.adoc index 1a1ab2335c..bcab469f59 100644 --- a/docs/manual/src/docs/asciidoc/_includes/servlet/saml2/saml2-login.adoc +++ b/docs/manual/src/docs/asciidoc/_includes/servlet/saml2/saml2-login.adoc @@ -1610,3 +1610,33 @@ http { The success handler will send logout requests to the asserting party. The request matcher will detect logout requests from the asserting party. + +[[servlet-saml2login-store-authn-request]] +=== Storing the `AuthnRequest` + +The `Saml2AuthenticationRequestRepository` is responsible for the persistence of the `AuthnRequest` from the time the `AuthnRequest` <> to the time the `SAMLResponse` <>. +The `Saml2AuthenticationTokenConverter` is responsible for loading the `AuthnRequest` from the `Saml2AuthenticationRequestRepository` and saving it into the `Saml2AuthenticationToken`. + +The default implementation of `Saml2AuthenticationRequestRepository` is `HttpSessionSaml2AuthenticationRequestRepository`, which stores the `AuthnRequest` in the `HttpSession`. + +If you have a custom implementation of `Saml2AuthenticationRequestRepository`, you may configure it by exposing it as a `@Bean` as shown in the following example: + +==== +.Java +[source,java,role="primary"] +---- +@Bean +Saml2AuthenticationRequestRepository authenticationRequestRepository() { + return new CustomSaml2AuthenticationRequestRepository(); +} +---- + +.Kotlin +[source,kotlin,role="secondary"] +---- +@Bean +open fun authenticationRequestRepository(): Saml2AuthenticationRequestRepository { + return CustomSaml2AuthenticationRequestRepository() +} +---- +==== diff --git a/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle b/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle index dceb1a355a..97b5547b20 100644 --- a/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle +++ b/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle @@ -55,6 +55,7 @@ dependencies { testImplementation "org.junit.jupiter:junit-jupiter-params" testImplementation "org.junit.jupiter:junit-jupiter-engine" testImplementation "org.mockito:mockito-core" + testImplementation "org.mockito:mockito-inline" testImplementation "org.mockito:mockito-junit-jupiter" testImplementation "org.springframework:spring-test" } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationToken.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationToken.java index 5f4f8fdb33..705c9d2818 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationToken.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationToken.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -38,6 +38,32 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken { private final String saml2Response; + private final AbstractSaml2AuthenticationRequest authenticationRequest; + + /** + * Creates a {@link Saml2AuthenticationToken} with the provided parameters. + * + * Note that the given {@link RelyingPartyRegistration} should have all its templates + * resolved at this point. See + * {@link org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter} + * for an example of performing that resolution. + * @param relyingPartyRegistration the resolved {@link RelyingPartyRegistration} to + * use + * @param saml2Response the SAML 2.0 response to authenticate + * @param authenticationRequest the {@code AuthNRequest} sent to the asserting party + * + * @since 5.6 + */ + public Saml2AuthenticationToken(RelyingPartyRegistration relyingPartyRegistration, String saml2Response, + AbstractSaml2AuthenticationRequest authenticationRequest) { + super(Collections.emptyList()); + Assert.notNull(relyingPartyRegistration, "relyingPartyRegistration cannot be null"); + Assert.notNull(saml2Response, "saml2Response cannot be null"); + this.relyingPartyRegistration = relyingPartyRegistration; + this.saml2Response = saml2Response; + this.authenticationRequest = authenticationRequest; + } + /** * Creates a {@link Saml2AuthenticationToken} with the provided parameters * @@ -52,11 +78,7 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken { * @since 5.4 */ public Saml2AuthenticationToken(RelyingPartyRegistration relyingPartyRegistration, String saml2Response) { - super(Collections.emptyList()); - Assert.notNull(relyingPartyRegistration, "relyingPartyRegistration cannot be null"); - Assert.notNull(saml2Response, "saml2Response cannot be null"); - this.relyingPartyRegistration = relyingPartyRegistration; - this.saml2Response = saml2Response; + this(relyingPartyRegistration, saml2Response, null); } /** @@ -81,6 +103,7 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken { .entityId(idpEntityId).singleSignOnServiceLocation(idpEntityId)) .build(); this.saml2Response = saml2Response; + this.authenticationRequest = null; } /** @@ -179,4 +202,14 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken { return this.relyingPartyRegistration.getAssertingPartyDetails().getEntityId(); } + /** + * Returns the authentication request sent to the assertion party or {@code null} if + * no authentication request is present + * @return the authentication request sent to the assertion party + * @since 5.6 + */ + public AbstractSaml2AuthenticationRequest getAuthenticationRequest() { + return this.authenticationRequest; + } + } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/HttpSessionSaml2AuthenticationRequestRepository.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/HttpSessionSaml2AuthenticationRequestRepository.java new file mode 100644 index 0000000000..eab739ce9c --- /dev/null +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/HttpSessionSaml2AuthenticationRequestRepository.java @@ -0,0 +1,73 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.servlet; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpSession; + +import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest; + +/** + * A {@link Saml2AuthenticationRequestRepository} implementation that uses + * {@link HttpSession} to store and retrieve the + * {@link AbstractSaml2AuthenticationRequest} + * + * @author Marcus Da Coregio + * @since 5.6 + */ +public class HttpSessionSaml2AuthenticationRequestRepository + implements Saml2AuthenticationRequestRepository { + + private static final String DEFAULT_SAML2_AUTHN_REQUEST_ATTR_NAME = HttpSessionSaml2AuthenticationRequestRepository.class + .getName().concat(".SAML2_AUTHN_REQUEST"); + + private String saml2AuthnRequestAttributeName = DEFAULT_SAML2_AUTHN_REQUEST_ATTR_NAME; + + @Override + public AbstractSaml2AuthenticationRequest loadAuthenticationRequest(HttpServletRequest request) { + HttpSession httpSession = request.getSession(false); + if (httpSession == null) { + return null; + } + return (AbstractSaml2AuthenticationRequest) httpSession.getAttribute(this.saml2AuthnRequestAttributeName); + } + + @Override + public void saveAuthenticationRequest(AbstractSaml2AuthenticationRequest authenticationRequest, + HttpServletRequest request, HttpServletResponse response) { + if (authenticationRequest == null) { + removeAuthenticationRequest(request, response); + return; + } + HttpSession httpSession = request.getSession(); + httpSession.setAttribute(this.saml2AuthnRequestAttributeName, authenticationRequest); + } + + @Override + public AbstractSaml2AuthenticationRequest removeAuthenticationRequest(HttpServletRequest request, + HttpServletResponse response) { + AbstractSaml2AuthenticationRequest authenticationRequest = loadAuthenticationRequest(request); + if (authenticationRequest == null) { + return null; + } + HttpSession httpSession = request.getSession(); + httpSession.removeAttribute(this.saml2AuthnRequestAttributeName); + return authenticationRequest; + } + +} diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/Saml2AuthenticationRequestRepository.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/Saml2AuthenticationRequestRepository.java new file mode 100644 index 0000000000..379f4c43db --- /dev/null +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/Saml2AuthenticationRequestRepository.java @@ -0,0 +1,60 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.servlet; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest; + +/** + * A repository for {@link AbstractSaml2AuthenticationRequest} + * + * @param the type of SAML 2.0 Authentication Request + * @author Marcus Da Coregio + * @since 5.6 + */ +public interface Saml2AuthenticationRequestRepository { + + /** + * Loads the {@link AbstractSaml2AuthenticationRequest} from the request + * @param request the current request + * @return the {@link AbstractSaml2AuthenticationRequest} or {@code null} if it is not + * present + */ + T loadAuthenticationRequest(HttpServletRequest request); + + /** + * Saves the current authentication request using the {@link HttpServletRequest} and + * {@link HttpServletResponse} + * @param authenticationRequest the {@link AbstractSaml2AuthenticationRequest} + * @param request the current request + * @param response the current response + */ + void saveAuthenticationRequest(T authenticationRequest, HttpServletRequest request, HttpServletResponse response); + + /** + * Removes the authentication request using the {@link HttpServletRequest} and + * {@link HttpServletResponse} + * @param request the current request + * @param response the current response + * @return the removed {@link AbstractSaml2AuthenticationRequest} or {@code null} if + * it is not present + */ + T removeAuthenticationRequest(HttpServletRequest request, HttpServletResponse response); + +} diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java index 2c2f833c83..c59ec4deeb 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,8 +23,11 @@ import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.security.saml2.core.Saml2Error; import org.springframework.security.saml2.core.Saml2ErrorCodes; +import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.saml2.provider.service.servlet.HttpSessionSaml2AuthenticationRequestRepository; +import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository; import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter; import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter; @@ -42,6 +45,8 @@ public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProce private final AuthenticationConverter authenticationConverter; + private Saml2AuthenticationRequestRepository authenticationRequestRepository = new HttpSessionSaml2AuthenticationRequestRepository(); + /** * Creates a {@code Saml2WebSsoAuthenticationFilter} authentication filter that is * configured to use the {@link #DEFAULT_FILTER_PROCESSES_URI} processing URL @@ -100,7 +105,33 @@ public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProce "No relying party registration found"); throw new Saml2AuthenticationException(saml2Error); } + this.authenticationRequestRepository.removeAuthenticationRequest(request, response); return getAuthenticationManager().authenticate(authentication); } + /** + * Use the given {@link Saml2AuthenticationRequestRepository} to remove the saved + * authentication request. If the {@link #authenticationConverter} is of the type + * {@link Saml2AuthenticationTokenConverter}, the + * {@link Saml2AuthenticationRequestRepository} will also be set into the + * {@link #authenticationConverter}. + * @param authenticationRequestRepository the + * {@link Saml2AuthenticationRequestRepository} to use + * @since 5.6 + */ + public void setAuthenticationRequestRepository( + Saml2AuthenticationRequestRepository authenticationRequestRepository) { + Assert.notNull(authenticationRequestRepository, "authenticationRequestRepository cannot be null"); + this.authenticationRequestRepository = authenticationRequestRepository; + setAuthenticationRequestRepositoryIntoAuthenticationConverter(authenticationRequestRepository); + } + + private void setAuthenticationRequestRepositoryIntoAuthenticationConverter( + Saml2AuthenticationRequestRepository authenticationRequestRepository) { + if (this.authenticationConverter instanceof Saml2AuthenticationTokenConverter) { + Saml2AuthenticationTokenConverter authenticationTokenConverter = (Saml2AuthenticationTokenConverter) this.authenticationConverter; + authenticationTokenConverter.setAuthenticationRequestRepository(authenticationRequestRepository); + } + } + } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java index b9b1f685f4..39819a513a 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java @@ -27,6 +27,7 @@ import javax.servlet.http.HttpServletResponse; import org.opensaml.core.Version; import org.springframework.http.MediaType; +import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory; import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest; @@ -34,6 +35,8 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2R import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; +import org.springframework.security.saml2.provider.service.servlet.HttpSessionSaml2AuthenticationRequestRepository; +import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository; import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver; import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver; @@ -79,6 +82,8 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter private RequestMatcher redirectMatcher = new AntPathRequestMatcher("/saml2/authenticate/{registrationId}"); + private Saml2AuthenticationRequestRepository authenticationRequestRepository = new HttpSessionSaml2AuthenticationRequestRepository(); + /** * Construct a {@link Saml2WebSsoAuthenticationRequestFilter} with the provided * parameters @@ -149,6 +154,19 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter this.redirectMatcher = redirectMatcher; } + /** + * Use the given {@link Saml2AuthenticationRequestRepository} to save the + * authentication request + * @param authenticationRequestRepository the + * {@link Saml2AuthenticationRequestRepository} to use + * @since 5.6 + */ + public void setAuthenticationRequestRepository( + Saml2AuthenticationRequestRepository authenticationRequestRepository) { + Assert.notNull(authenticationRequestRepository, "authenticationRequestRepository cannot be null"); + this.authenticationRequestRepository = authenticationRequestRepository; + } + @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { @@ -165,17 +183,18 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter } RelyingPartyRegistration relyingParty = context.getRelyingPartyRegistration(); if (relyingParty.getAssertingPartyDetails().getSingleSignOnServiceBinding() == Saml2MessageBinding.REDIRECT) { - sendRedirect(response, context); + sendRedirect(request, response, context); } else { - sendPost(response, context); + sendPost(request, response, context); } } - private void sendRedirect(HttpServletResponse response, Saml2AuthenticationRequestContext context) - throws IOException { + private void sendRedirect(HttpServletRequest request, HttpServletResponse response, + Saml2AuthenticationRequestContext context) throws IOException { Saml2RedirectAuthenticationRequest authenticationRequest = this.authenticationRequestFactory .createRedirectAuthenticationRequest(context); + this.authenticationRequestRepository.saveAuthenticationRequest(authenticationRequest, request, response); UriComponentsBuilder uriBuilder = UriComponentsBuilder .fromUriString(authenticationRequest.getAuthenticationRequestUri()); addParameter("SAMLRequest", authenticationRequest.getSamlRequest(), uriBuilder); @@ -194,9 +213,11 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter } } - private void sendPost(HttpServletResponse response, Saml2AuthenticationRequestContext context) throws IOException { + private void sendPost(HttpServletRequest request, HttpServletResponse response, + Saml2AuthenticationRequestContext context) throws IOException { Saml2PostAuthenticationRequest authenticationRequest = this.authenticationRequestFactory .createPostAuthenticationRequest(context); + this.authenticationRequestRepository.saveAuthenticationRequest(authenticationRequest, request, response); String html = createSamlPostRequestFormData(authenticationRequest); response.setContentType(MediaType.TEXT_HTML_VALUE); response.getWriter().write(html); diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java index 06062543fe..91f8f3e95c 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java @@ -18,6 +18,7 @@ package org.springframework.security.saml2.provider.service.web; import java.io.ByteArrayOutputStream; import java.nio.charset.StandardCharsets; +import java.util.function.Function; import java.util.zip.Inflater; import java.util.zip.InflaterOutputStream; @@ -30,9 +31,12 @@ import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpMethod; import org.springframework.security.saml2.core.Saml2Error; import org.springframework.security.saml2.core.Saml2ErrorCodes; +import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.servlet.HttpSessionSaml2AuthenticationRequestRepository; +import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository; import org.springframework.security.web.authentication.AuthenticationConverter; import org.springframework.util.Assert; @@ -50,6 +54,8 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo private final Converter relyingPartyRegistrationResolver; + private Function loader; + /** * Constructs a {@link Saml2AuthenticationTokenConverter} given a strategy for * resolving {@link RelyingPartyRegistration}s @@ -60,6 +66,7 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo Converter relyingPartyRegistrationResolver) { Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null"); this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver; + this.loader = new HttpSessionSaml2AuthenticationRequestRepository()::loadAuthenticationRequest; } @Override @@ -74,7 +81,25 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo } byte[] b = samlDecode(saml2Response); saml2Response = inflateIfRequired(request, b); - return new Saml2AuthenticationToken(relyingPartyRegistration, saml2Response); + AbstractSaml2AuthenticationRequest authenticationRequest = loadAuthenticationRequest(request); + return new Saml2AuthenticationToken(relyingPartyRegistration, saml2Response, authenticationRequest); + } + + /** + * Use the given {@link Saml2AuthenticationRequestRepository} to load authentication + * request. + * @param authenticationRequestRepository the + * {@link Saml2AuthenticationRequestRepository} to use + * @since 5.6 + */ + public void setAuthenticationRequestRepository( + Saml2AuthenticationRequestRepository authenticationRequestRepository) { + Assert.notNull(authenticationRequestRepository, "authenticationRequestRepository cannot be null"); + this.loader = authenticationRequestRepository::loadAuthenticationRequest; + } + + private AbstractSaml2AuthenticationRequest loadAuthenticationRequest(HttpServletRequest request) { + return this.loader.apply(request); } private String inflateIfRequired(HttpServletRequest request, byte[] b) { diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestSaml2AuthenticationTokens.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestSaml2AuthenticationTokens.java new file mode 100644 index 0000000000..36784023fb --- /dev/null +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestSaml2AuthenticationTokens.java @@ -0,0 +1,38 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.authentication; + +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; + +/** + * Tests instances for {@link Saml2AuthenticationToken} + * + * @author Marcus Da Coregio + */ +public final class TestSaml2AuthenticationTokens { + + private TestSaml2AuthenticationTokens() { + } + + public static Saml2AuthenticationToken token() { + RelyingPartyRegistration relyingPartyRegistration = TestRelyingPartyRegistrations.relyingPartyRegistration() + .build(); + return new Saml2AuthenticationToken(relyingPartyRegistration, "saml2-xml-response-object"); + } + +} diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/HttpSessionSaml2AuthenticationRequestRepositoryTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/HttpSessionSaml2AuthenticationRequestRepositoryTests.java new file mode 100644 index 0000000000..9e9463f62b --- /dev/null +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/HttpSessionSaml2AuthenticationRequestRepositoryTests.java @@ -0,0 +1,143 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.servlet; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.mock.web.MockHttpSession; +import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** + * @author Marcus Da Coregio + */ +public class HttpSessionSaml2AuthenticationRequestRepositoryTests { + + private static final String IDP_SSO_URL = "https://sso-url.example.com/IDP/SSO"; + + private MockHttpServletRequest request; + + private MockHttpServletResponse response; + + private HttpSessionSaml2AuthenticationRequestRepository authenticationRequestRepository; + + @BeforeEach + public void setup() { + this.request = new MockHttpServletRequest(); + this.response = new MockHttpServletResponse(); + this.authenticationRequestRepository = new HttpSessionSaml2AuthenticationRequestRepository(); + } + + @Test + public void loadAuthenticationRequestWhenInvalidSessionThenNull() { + AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository + .loadAuthenticationRequest(this.request); + assertThat(authenticationRequest).isNull(); + } + + @Test + public void loadAuthenticationRequestWhenNoAttributeInSessionThenNull() { + this.request.getSession(); + AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository + .loadAuthenticationRequest(this.request); + assertThat(authenticationRequest).isNull(); + } + + @Test + public void loadAuthenticationRequestWhenAttributeInSessionThenReturnsAuthenticationRequest() { + AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mock(AbstractSaml2AuthenticationRequest.class); + given(mockAuthenticationRequest.getAuthenticationRequestUri()).willReturn(IDP_SSO_URL); + this.request.getSession(); + this.authenticationRequestRepository.saveAuthenticationRequest(mockAuthenticationRequest, this.request, + this.response); + AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository + .loadAuthenticationRequest(this.request); + assertThat(authenticationRequest.getAuthenticationRequestUri()).isEqualTo(IDP_SSO_URL); + } + + @Test + public void saveAuthenticationRequestWhenSessionDontExistsThenCreateAndSave() { + AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mock(AbstractSaml2AuthenticationRequest.class); + this.authenticationRequestRepository.saveAuthenticationRequest(mockAuthenticationRequest, this.request, + this.response); + AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository + .loadAuthenticationRequest(this.request); + assertThat(authenticationRequest).isNotNull(); + } + + @Test + public void saveAuthenticationRequestWhenSessionExistsThenSave() { + AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mock(AbstractSaml2AuthenticationRequest.class); + this.request.getSession(); + this.authenticationRequestRepository.saveAuthenticationRequest(mockAuthenticationRequest, this.request, + this.response); + AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository + .loadAuthenticationRequest(this.request); + assertThat(authenticationRequest).isNotNull(); + } + + @Test + public void saveAuthenticationRequestWhenNullAuthenticationRequestThenDontSave() { + this.request.getSession(); + this.authenticationRequestRepository.saveAuthenticationRequest(null, this.request, this.response); + AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository + .loadAuthenticationRequest(this.request); + assertThat(authenticationRequest).isNull(); + } + + @Test + public void removeAuthenticationRequestWhenInvalidSessionThenReturnNull() { + AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository + .removeAuthenticationRequest(this.request, this.response); + assertThat(authenticationRequest).isNull(); + } + + @Test + public void removeAuthenticationRequestWhenAttributeInSessionThenRemoveAuthenticationRequest() { + AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mock(AbstractSaml2AuthenticationRequest.class); + given(mockAuthenticationRequest.getAuthenticationRequestUri()).willReturn(IDP_SSO_URL); + this.request.getSession(); + this.authenticationRequestRepository.saveAuthenticationRequest(mockAuthenticationRequest, this.request, + this.response); + AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository + .removeAuthenticationRequest(this.request, this.response); + AbstractSaml2AuthenticationRequest authenticationRequestAfterRemove = this.authenticationRequestRepository + .loadAuthenticationRequest(this.request); + assertThat(authenticationRequest.getAuthenticationRequestUri()).isEqualTo(IDP_SSO_URL); + assertThat(authenticationRequestAfterRemove).isNull(); + } + + @Test + public void removeAuthenticationRequestWhenValidSessionNoAttributeThenReturnsNull() { + MockHttpSession session = mock(MockHttpSession.class); + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setSession(session); + AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository + .removeAuthenticationRequest(request, this.response); + verify(session).getAttribute(anyString()); + assertThat(authenticationRequest).isNull(); + } + +} diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilterTests.java index 9c5a660fe0..ffece463e1 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilterTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,12 +24,20 @@ import org.junit.jupiter.api.Test; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException; +import org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationTokens; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository; +import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter; +import org.springframework.security.web.authentication.AuthenticationConverter; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; public class Saml2WebSsoAuthenticationFilterTests { @@ -84,4 +92,45 @@ public class Saml2WebSsoAuthenticationFilterTests { .withMessage("No relying party registration found"); } + @Test + public void attemptAuthenticationWhenSavedAuthnRequestThenRemovesAuthnRequest() { + Saml2AuthenticationRequestRepository authenticationRequestRepository = mock( + Saml2AuthenticationRequestRepository.class); + AuthenticationConverter authenticationConverter = mock(AuthenticationConverter.class); + given(authenticationConverter.convert(this.request)).willReturn(TestSaml2AuthenticationTokens.token()); + this.filter = new Saml2WebSsoAuthenticationFilter(authenticationConverter, "/some/other/path/{registrationId}"); + this.filter.setAuthenticationManager((authentication) -> null); + this.request.setPathInfo("/some/other/path/idp-registration-id"); + this.filter.setAuthenticationRequestRepository(authenticationRequestRepository); + this.filter.attemptAuthentication(this.request, this.response); + verify(authenticationRequestRepository).removeAuthenticationRequest(this.request, this.response); + } + + @Test + public void setAuthenticationRequestRepositoryWhenNullThenThrowsIllegalArgument() { + assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthenticationRequestRepository(null)) + .withMessage("authenticationRequestRepository cannot be null"); + } + + @Test + public void setAuthenticationRequestRepositoryWhenExpectedAuthenticationConverterTypeThenSetLoaderIntoConverter() { + Saml2AuthenticationTokenConverter authenticationConverterMock = mock(Saml2AuthenticationTokenConverter.class); + Saml2AuthenticationRequestRepository authenticationRequestRepository = mock( + Saml2AuthenticationRequestRepository.class); + this.filter = new Saml2WebSsoAuthenticationFilter(authenticationConverterMock, + "/some/other/path/{registrationId}"); + this.filter.setAuthenticationRequestRepository(authenticationRequestRepository); + verify(authenticationConverterMock).setAuthenticationRequestRepository(authenticationRequestRepository); + } + + @Test + public void setAuthenticationRequestRepositoryWhenNotExpectedAuthenticationConverterTypeThenDontSet() { + AuthenticationConverter authenticationConverter = mock(AuthenticationConverter.class); + Saml2AuthenticationRequestRepository authenticationRequestRepository = mock( + Saml2AuthenticationRequestRepository.class); + this.filter = new Saml2WebSsoAuthenticationFilter(authenticationConverter, "/some/other/path/{registrationId}"); + this.filter.setAuthenticationRequestRepository(authenticationRequestRepository); + verifyNoInteractions(authenticationConverter); + } + } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java index d1832cf1d9..0eda04f267 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,6 +28,7 @@ import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.saml2.credentials.TestSaml2X509Credentials; +import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory; import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest; @@ -36,6 +37,7 @@ import org.springframework.security.saml2.provider.service.authentication.TestSa import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; +import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver; import org.springframework.web.util.HtmlUtils; import org.springframework.web.util.UriUtils; @@ -43,6 +45,7 @@ import org.springframework.web.util.UriUtils; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -60,6 +63,9 @@ public class Saml2WebSsoAuthenticationRequestFilterTests { private Saml2AuthenticationRequestContextResolver resolver = mock(Saml2AuthenticationRequestContextResolver.class); + private Saml2AuthenticationRequestRepository authenticationRequestRepository = mock( + Saml2AuthenticationRequestRepository.class); + private MockHttpServletRequest request; private MockHttpServletResponse response; @@ -79,6 +85,7 @@ public class Saml2WebSsoAuthenticationRequestFilterTests { .providerDetails((c) -> c.entityId("idp-entity-id")).providerDetails((c) -> c.webSsoUrl(IDP_SSO_URL)) .assertionConsumerServiceUrlTemplate("template") .credentials((c) -> c.add(TestSaml2X509Credentials.assertingPartyPrivateCredential())); + this.filter.setAuthenticationRequestRepository(this.authenticationRequestRepository); } @Test @@ -216,4 +223,37 @@ public class Saml2WebSsoAuthenticationRequestFilterTests { assertThat(this.response.getStatus()).isEqualTo(401); } + @Test + public void setAuthenticationRequestRepositoryWhenNullThenException() { + Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter(this.resolver, + this.factory); + assertThatIllegalArgumentException().isThrownBy(() -> filter.setAuthenticationRequestRepository(null)); + } + + @Test + public void doFilterWhenRedirectThenSaveRedirectRequest() throws ServletException, IOException { + Saml2AuthenticationRequestContext context = authenticationRequestContext().build(); + Saml2RedirectAuthenticationRequest request = redirectAuthenticationRequest(context).build(); + given(this.resolver.resolve(any())).willReturn(context); + given(this.factory.createRedirectAuthenticationRequest(any())).willReturn(request); + this.filter.doFilterInternal(this.request, this.response, this.filterChain); + verify(this.authenticationRequestRepository).saveAuthenticationRequest( + any(Saml2RedirectAuthenticationRequest.class), eq(this.request), eq(this.response)); + } + + @Test + public void doFilterWhenPostThenSaveRedirectRequest() throws ServletException, IOException { + RelyingPartyRegistration registration = this.rpBuilder + .assertingPartyDetails((asserting) -> asserting.singleSignOnServiceBinding(Saml2MessageBinding.POST)) + .build(); + Saml2AuthenticationRequestContext context = authenticationRequestContext() + .relyingPartyRegistration(registration).build(); + Saml2PostAuthenticationRequest request = postAuthenticationRequest(context).build(); + given(this.resolver.resolve(any())).willReturn(context); + given(this.factory.createPostAuthenticationRequest(any())).willReturn(request); + this.filter.doFilterInternal(this.request, this.response, this.filterChain); + verify(this.authenticationRequestRepository).saveAuthenticationRequest( + any(Saml2PostAuthenticationRequest.class), eq(this.request), eq(this.response)); + } + } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java index 628554a461..ecd69f7dea 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java @@ -31,10 +31,12 @@ import org.springframework.core.io.ClassPathResource; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.security.saml2.core.Saml2ErrorCodes; import org.springframework.security.saml2.core.Saml2Utils; +import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; +import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository; import org.springframework.util.StreamUtils; import org.springframework.web.util.UriUtils; @@ -43,6 +45,7 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; @ExtendWith(MockitoExtension.class) public class Saml2AuthenticationTokenConverterTests { @@ -155,6 +158,27 @@ public class Saml2AuthenticationTokenConverterTests { validateSsoCircleXml(token.getSaml2Response()); } + @Test + public void convertWhenSavedAuthenticationRequestThenToken() { + Saml2AuthenticationRequestRepository authenticationRequestRepository = mock( + Saml2AuthenticationRequestRepository.class); + AbstractSaml2AuthenticationRequest authenticationRequest = mock(AbstractSaml2AuthenticationRequest.class); + Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter( + this.relyingPartyRegistrationResolver); + converter.setAuthenticationRequestRepository(authenticationRequestRepository); + given(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class))) + .willReturn(this.relyingPartyRegistration); + given(authenticationRequestRepository.loadAuthenticationRequest(any(HttpServletRequest.class))) + .willReturn(authenticationRequest); + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setParameter("SAMLResponse", Saml2Utils.samlEncode("response".getBytes(StandardCharsets.UTF_8))); + Saml2AuthenticationToken token = converter.convert(request); + assertThat(token.getSaml2Response()).isEqualTo("response"); + assertThat(token.getRelyingPartyRegistration().getRegistrationId()) + .isEqualTo(this.relyingPartyRegistration.getRegistrationId()); + assertThat(token.getAuthenticationRequest()).isEqualTo(authenticationRequest); + } + private void validateSsoCircleXml(String xml) { assertThat(xml).contains("InResponseTo=\"ARQ9a73ead-7dcf-45a8-89eb-26f3c9900c36\"") .contains(" ID=\"s246d157446618e90e43fb79bdd4d9e9e19cf2c7c4\"")