diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutRequestFilter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutRequestFilter.java index 603fbac8eb..816b5e8bcb 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutRequestFilter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutRequestFilter.java @@ -30,6 +30,7 @@ import org.springframework.core.log.LogMessage; import org.springframework.http.MediaType; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.security.saml2.core.Saml2ParameterNames; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticatedPrincipal; import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequest; @@ -64,6 +65,9 @@ public final class Saml2LogoutRequestFilter extends OncePerRequestFilter { private final Log logger = LogFactory.getLog(getClass()); + private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder + .getContextHolderStrategy(); + private final Saml2LogoutRequestValidator logoutRequestValidator; private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver; @@ -108,7 +112,7 @@ public final class Saml2LogoutRequestFilter extends OncePerRequestFilter { return; } - Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); + Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication(); RelyingPartyRegistration registration = this.relyingPartyRegistrationResolver.resolve(request, getRegistrationId(authentication)); if (registration == null) { @@ -168,6 +172,17 @@ public final class Saml2LogoutRequestFilter extends OncePerRequestFilter { this.logoutRequestMatcher = logoutRequestMatcher; } + /** + * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use + * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}. + * + * @since 5.8 + */ + public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) { + Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null"); + this.securityContextHolderStrategy = securityContextHolderStrategy; + } + private String getRegistrationId(Authentication authentication) { if (authentication == null) { return null; diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutRequestFilterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutRequestFilterTests.java index bb604c4775..b185da6391 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutRequestFilterTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutRequestFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,6 +25,8 @@ import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; +import org.springframework.security.core.context.SecurityContextImpl; import org.springframework.security.saml2.core.Saml2Error; import org.springframework.security.saml2.core.Saml2ParameterNames; import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequestValidator; @@ -48,6 +50,8 @@ import static org.mockito.Mockito.verifyNoInteractions; */ public class Saml2LogoutRequestFilterTests { + SecurityContextHolderStrategy securityContextHolderStrategy = mock(SecurityContextHolderStrategy.class); + RelyingPartyRegistrationResolver relyingPartyRegistrationResolver = mock(RelyingPartyRegistrationResolver.class); Saml2LogoutRequestValidator logoutRequestValidator = mock(Saml2LogoutRequestValidator.class); @@ -94,6 +98,8 @@ public class Saml2LogoutRequestFilterTests { RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full() .assertingPartyDetails((party) -> party.singleLogoutServiceBinding(Saml2MessageBinding.POST)).build(); Authentication authentication = new TestingAuthenticationToken("user", "password"); + given(this.securityContextHolderStrategy.getContext()).willReturn(new SecurityContextImpl(authentication)); + this.logoutRequestProcessingFilter.setSecurityContextHolderStrategy(this.securityContextHolderStrategy); SecurityContextHolder.getContext().setAuthentication(authentication); MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo"); request.setServletPath("/logout/saml2/slo"); @@ -111,6 +117,7 @@ public class Saml2LogoutRequestFilterTests { String content = response.getContentAsString(); assertThat(content).contains(Saml2ParameterNames.SAML_RESPONSE); assertThat(content).contains(registration.getAssertingPartyDetails().getSingleLogoutServiceResponseLocation()); + verify(this.securityContextHolderStrategy).getContext(); } @Test