diff --git a/config/src/main/java/org/springframework/security/config/http/AuthenticationConfigBuilder.java b/config/src/main/java/org/springframework/security/config/http/AuthenticationConfigBuilder.java index f67c700412..7eb17c0384 100644 --- a/config/src/main/java/org/springframework/security/config/http/AuthenticationConfigBuilder.java +++ b/config/src/main/java/org/springframework/security/config/http/AuthenticationConfigBuilder.java @@ -239,7 +239,7 @@ final class AuthenticationConfigBuilder { createX509Filter(authenticationManager, authenticationFilterSecurityContextHolderStrategyRef); createJeeFilter(authenticationManager, authenticationFilterSecurityContextHolderStrategyRef); createLogoutFilter(authenticationFilterSecurityContextHolderStrategyRef); - createSaml2LogoutFilter(); + createSaml2LogoutFilter(authenticationFilterSecurityContextHolderStrategyRef); createLoginPageFilterIfNeeded(); createUserDetailsServiceFactory(); createExceptionTranslationFilter(authenticationFilterSecurityContextHolderStrategyRef); @@ -635,13 +635,13 @@ final class AuthenticationConfigBuilder { } } - private void createSaml2LogoutFilter() { + private void createSaml2LogoutFilter(BeanMetadataElement authenticationFilterSecurityContextHolderStrategyRef) { Element saml2LogoutElt = DomUtils.getChildElementByTagName(this.httpElt, Elements.SAML2_LOGOUT); if (saml2LogoutElt == null) { return; } Saml2LogoutBeanDefinitionParser parser = new Saml2LogoutBeanDefinitionParser(this.logoutHandlers, - this.logoutSuccessHandler); + this.logoutSuccessHandler, authenticationFilterSecurityContextHolderStrategyRef); parser.parse(saml2LogoutElt, this.pc); BeanDefinition saml2LogoutFilter = parser.getLogoutFilter(); BeanDefinition saml2LogoutRequestFilter = parser.getLogoutRequestFilter(); diff --git a/config/src/main/java/org/springframework/security/config/http/Saml2LogoutBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/http/Saml2LogoutBeanDefinitionParser.java index 8c61129ea8..a735440594 100644 --- a/config/src/main/java/org/springframework/security/config/http/Saml2LogoutBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/http/Saml2LogoutBeanDefinitionParser.java @@ -77,10 +77,14 @@ final class Saml2LogoutBeanDefinitionParser implements BeanDefinitionParser { private BeanDefinition logoutFilter; + private BeanMetadataElement authenticationFilterSecurityContextHolderStrategy; + Saml2LogoutBeanDefinitionParser(ManagedList logoutHandlers, - BeanMetadataElement logoutSuccessHandler) { + BeanMetadataElement logoutSuccessHandler, + BeanMetadataElement authenticationFilterSecurityContextHolderStrategy) { this.logoutHandlers = logoutHandlers; this.logoutSuccessHandler = logoutSuccessHandler; + this.authenticationFilterSecurityContextHolderStrategy = authenticationFilterSecurityContextHolderStrategy; } @Override @@ -119,7 +123,10 @@ final class Saml2LogoutBeanDefinitionParser implements BeanDefinitionParser { this.logoutRequestFilter = BeanDefinitionBuilder.rootBeanDefinition(Saml2LogoutRequestFilter.class) .addConstructorArgValue(registrations).addConstructorArgValue(logoutRequestValidator) .addConstructorArgValue(logoutResponseResolver).addConstructorArgValue(this.logoutHandlers) - .addPropertyValue("logoutRequestMatcher", logoutRequestMatcher).getBeanDefinition(); + .addPropertyValue("logoutRequestMatcher", logoutRequestMatcher) + .addPropertyValue("securityContextHolderStrategy", + this.authenticationFilterSecurityContextHolderStrategy) + .getBeanDefinition(); BeanMetadataElement logoutResponseValidator = Saml2LogoutBeanDefinitionParserUtils .getLogoutResponseValidator(element); BeanMetadataElement logoutRequestRepository = Saml2LogoutBeanDefinitionParserUtils diff --git a/config/src/test/java/org/springframework/security/config/http/Saml2LoginBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/http/Saml2LoginBeanDefinitionParserTests.java index 7e28395ed2..d93e75433b 100644 --- a/config/src/test/java/org/springframework/security/config/http/Saml2LoginBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/http/Saml2LoginBeanDefinitionParserTests.java @@ -32,6 +32,7 @@ import org.springframework.security.config.test.SpringTestContext; import org.springframework.security.config.test.SpringTestContextExtension; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; +import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.security.saml2.core.Saml2ParameterNames; import org.springframework.security.saml2.core.Saml2Utils; import org.springframework.security.saml2.core.TestSaml2X509Credentials; @@ -61,6 +62,7 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; @@ -178,6 +180,23 @@ public class Saml2LoginBeanDefinitionParserTests { assertThat(authentication.getPrincipal()).isInstanceOf(Saml2AuthenticatedPrincipal.class); } + @Test + public void authenticateWhenCustomSecurityContextHolderStrategyThenUses() throws Exception { + this.spring.configLocations(this.xml("WithCustomSecurityContextHolderStrategy")).autowire(); + RelyingPartyRegistration relyingPartyRegistration = relyingPartyRegistrationWithVerifyingCredential(); + // @formatter:off + this.mvc.perform(post("/login/saml2/sso/" + relyingPartyRegistration.getRegistrationId()).param(Saml2ParameterNames.SAML_RESPONSE, SIGNED_RESPONSE)) + .andDo(MockMvcResultHandlers.print()) + .andExpect(status().is2xxSuccessful()); + // @formatter:on + ArgumentCaptor authenticationCaptor = ArgumentCaptor.forClass(Authentication.class); + verify(this.authenticationSuccessHandler).onAuthenticationSuccess(any(), any(), authenticationCaptor.capture()); + Authentication authentication = authenticationCaptor.getValue(); + assertThat(authentication.getPrincipal()).isInstanceOf(Saml2AuthenticatedPrincipal.class); + SecurityContextHolderStrategy strategy = this.spring.getContext().getBean(SecurityContextHolderStrategy.class); + verify(strategy, atLeastOnce()).getContext(); + } + @Test public void authenticateWhenAuthenticationResponseValidThenAuthenticationSuccessEventPublished() throws Exception { this.spring.configLocations(this.xml("WithCustomRelyingPartyRepository")).autowire(); diff --git a/config/src/test/java/org/springframework/security/config/http/Saml2LogoutBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/http/Saml2LogoutBeanDefinitionParserTests.java index 4dc54b43f0..f7ab2e75a1 100644 --- a/config/src/test/java/org/springframework/security/config/http/Saml2LogoutBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/http/Saml2LogoutBeanDefinitionParserTests.java @@ -35,6 +35,7 @@ import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.config.test.SpringTestContext; import org.springframework.security.config.test.SpringTestContextExtension; import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.security.saml2.core.Saml2Utils; import org.springframework.security.saml2.provider.service.authentication.DefaultSaml2AuthenticatedPrincipal; import org.springframework.security.saml2.provider.service.authentication.Saml2Authentication; @@ -63,6 +64,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.hamcrest.Matchers.containsString; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.verify; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; @@ -233,6 +235,23 @@ public class Saml2LogoutBeanDefinitionParserTests { assertThat(location).startsWith("https://ap.example.org/logout/saml2/response"); } + @Test + public void saml2LogoutRequestWhenCustomSecurityContextHolderStrategyThenUses() throws Exception { + this.spring.configLocations(this.xml("WithSecurityContextHolderStrategy")).autowire(); + DefaultSaml2AuthenticatedPrincipal principal = new DefaultSaml2AuthenticatedPrincipal("user", + Collections.emptyMap()); + principal.setRelyingPartyRegistrationId("get"); + Saml2Authentication user = new Saml2Authentication(principal, "response", + AuthorityUtils.createAuthorityList("ROLE_USER")); + MvcResult result = this.mvc.perform(get("/logout/saml2/slo").param("SAMLRequest", this.apLogoutRequest) + .param("RelayState", this.apLogoutRequestRelayState).param("SigAlg", this.apLogoutRequestSigAlg) + .param("Signature", this.apLogoutRequestSignature).with(samlQueryString()).with(authentication(user))) + .andExpect(status().isFound()).andReturn(); + String location = result.getResponse().getHeader("Location"); + assertThat(location).startsWith("https://ap.example.org/logout/saml2/response"); + verify(getBean(SecurityContextHolderStrategy.class), atLeastOnce()).getContext(); + } + @Test public void saml2LogoutRequestWhenNoRegistrationThen400() throws Exception { this.spring.configLocations(this.xml("Default")).autowire(); diff --git a/config/src/test/resources/org/springframework/security/config/http/Saml2LoginBeanDefinitionParserTests-WithCustomSecurityContextHolderStrategy.xml b/config/src/test/resources/org/springframework/security/config/http/Saml2LoginBeanDefinitionParserTests-WithCustomSecurityContextHolderStrategy.xml new file mode 100644 index 0000000000..b33a7da30e --- /dev/null +++ b/config/src/test/resources/org/springframework/security/config/http/Saml2LoginBeanDefinitionParserTests-WithCustomSecurityContextHolderStrategy.xml @@ -0,0 +1,49 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/config/src/test/resources/org/springframework/security/config/http/Saml2LogoutBeanDefinitionParserTests-WithSecurityContextHolderStrategy.xml b/config/src/test/resources/org/springframework/security/config/http/Saml2LogoutBeanDefinitionParserTests-WithSecurityContextHolderStrategy.xml new file mode 100644 index 0000000000..9b0425c598 --- /dev/null +++ b/config/src/test/resources/org/springframework/security/config/http/Saml2LogoutBeanDefinitionParserTests-WithSecurityContextHolderStrategy.xml @@ -0,0 +1,41 @@ + + + + + + + + + + + + + + + + + + + +