From 51a0cffd361a72e7e1f9b4e502e6ec821633b2e4 Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Fri, 17 Apr 2020 16:48:59 -0600 Subject: [PATCH] Post-process AuthenticationRequestFilter Fixes gh-8552 --- .../saml2/Saml2LoginConfigurer.java | 4 +- .../saml2/Saml2LoginConfigurerTests.java | 56 +++++++++++++++++++ 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java index 45b917ed6c..73f33203c1 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java @@ -323,9 +323,9 @@ public final class Saml2LoginConfigurer> extend private Filter build(B http) { Saml2AuthenticationRequestFactory authenticationRequestResolver = getResolver(http); - return new Saml2WebSsoAuthenticationRequestFilter( + return postProcess(new Saml2WebSsoAuthenticationRequestFilter( Saml2LoginConfigurer.this.relyingPartyRegistrationRepository, - authenticationRequestResolver); + authenticationRequestResolver)); } private Saml2AuthenticationRequestFactory getResolver(B http) { diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java index 186f9e3b15..de9c20407e 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java @@ -23,6 +23,7 @@ import java.util.Base64; import java.util.Collection; import java.util.Collections; import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; import org.junit.After; import org.junit.Assert; @@ -55,9 +56,13 @@ import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider; import org.springframework.security.saml2.provider.service.authentication.Saml2Authentication; +import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter; +import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter; +import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver; import org.springframework.security.web.FilterChainProxy; import org.springframework.security.web.context.HttpRequestResponseHolder; import org.springframework.security.web.context.HttpSessionSecurityContextRepository; @@ -66,10 +71,15 @@ import org.springframework.test.util.ReflectionTestUtils; import org.springframework.test.web.servlet.MockMvc; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext; import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; /** * Tests for different Java configuration for {@link Saml2LoginConfigurer} @@ -133,6 +143,20 @@ public class Saml2LoginConfigurerTests { validateSaml2WebSsoAuthenticationFilterConfiguration(); } + @Test + public void saml2LoginWhenCustomAuthenticationRequestContextResolverThenUses() throws Exception { + this.spring.register(CustomAuthenticationRequestContextResolver.class).autowire(); + + Saml2AuthenticationRequestContext context = authenticationRequestContext().build(); + Saml2AuthenticationRequestContextResolver resolver = + CustomAuthenticationRequestContextResolver.resolver; + when(resolver.resolve(any(HttpServletRequest.class), any(RelyingPartyRegistration.class))) + .thenReturn(context); + this.mvc.perform(get("/saml2/authenticate/registration-id")) + .andExpect(status().isFound()); + verify(resolver).resolve(any(HttpServletRequest.class), any(RelyingPartyRegistration.class)); + } + private void validateSaml2WebSsoAuthenticationFilterConfiguration() { // get the OpenSamlAuthenticationProvider Saml2WebSsoAuthenticationFilter filter = getSaml2SsoFilter(this.springSecurityFilterChain); @@ -219,6 +243,38 @@ public class Saml2LoginConfigurerTests { } } + @EnableWebSecurity + @Import(Saml2LoginConfigBeans.class) + static class CustomAuthenticationRequestContextResolver extends WebSecurityConfigurerAdapter { + private static final Saml2AuthenticationRequestContextResolver resolver = + mock(Saml2AuthenticationRequestContextResolver.class); + + @Override + protected void configure(HttpSecurity http) throws Exception { + ObjectPostProcessor processor + = new ObjectPostProcessor() { + @Override + public O postProcess(O filter) { + filter.setAuthenticationRequestContextResolver(resolver); + return filter; + } + }; + + http + .authorizeRequests(authz -> authz + .anyRequest().authenticated() + ) + .saml2Login(saml2 -> saml2 + .addObjectPostProcessor(processor) + ); + } + + @Bean + Saml2AuthenticationRequestContextResolver resolver() { + return resolver; + } + } + private static AuthenticationManager getAuthenticationManagerMock(String role) { return new AuthenticationManager() {