Add SecurityContextHolderStrategy to Saml2

Issue gh-11060
This commit is contained in:
Josh Cummings 2022-06-21 16:32:45 -06:00
parent 9cd7c7b046
commit 3c8a80c364
No known key found for this signature in database
GPG Key ID: A306A51F43B8E5A5
2 changed files with 24 additions and 2 deletions

View File

@ -30,6 +30,7 @@ import org.springframework.core.log.LogMessage;
import org.springframework.http.MediaType;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.saml2.core.Saml2ParameterNames;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticatedPrincipal;
import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequest;
@ -64,6 +65,9 @@ public final class Saml2LogoutRequestFilter extends OncePerRequestFilter {
private final Log logger = LogFactory.getLog(getClass());
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
private final Saml2LogoutRequestValidator logoutRequestValidator;
private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver;
@ -108,7 +112,7 @@ public final class Saml2LogoutRequestFilter extends OncePerRequestFilter {
return;
}
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication();
RelyingPartyRegistration registration = this.relyingPartyRegistrationResolver.resolve(request,
getRegistrationId(authentication));
if (registration == null) {
@ -168,6 +172,17 @@ public final class Saml2LogoutRequestFilter extends OncePerRequestFilter {
this.logoutRequestMatcher = logoutRequestMatcher;
}
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*
* @since 5.8
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
private String getRegistrationId(Authentication authentication) {
if (authentication == null) {
return null;

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -25,6 +25,8 @@ import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.core.context.SecurityContextImpl;
import org.springframework.security.saml2.core.Saml2Error;
import org.springframework.security.saml2.core.Saml2ParameterNames;
import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequestValidator;
@ -48,6 +50,8 @@ import static org.mockito.Mockito.verifyNoInteractions;
*/
public class Saml2LogoutRequestFilterTests {
SecurityContextHolderStrategy securityContextHolderStrategy = mock(SecurityContextHolderStrategy.class);
RelyingPartyRegistrationResolver relyingPartyRegistrationResolver = mock(RelyingPartyRegistrationResolver.class);
Saml2LogoutRequestValidator logoutRequestValidator = mock(Saml2LogoutRequestValidator.class);
@ -94,6 +98,8 @@ public class Saml2LogoutRequestFilterTests {
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full()
.assertingPartyDetails((party) -> party.singleLogoutServiceBinding(Saml2MessageBinding.POST)).build();
Authentication authentication = new TestingAuthenticationToken("user", "password");
given(this.securityContextHolderStrategy.getContext()).willReturn(new SecurityContextImpl(authentication));
this.logoutRequestProcessingFilter.setSecurityContextHolderStrategy(this.securityContextHolderStrategy);
SecurityContextHolder.getContext().setAuthentication(authentication);
MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo");
request.setServletPath("/logout/saml2/slo");
@ -111,6 +117,7 @@ public class Saml2LogoutRequestFilterTests {
String content = response.getContentAsString();
assertThat(content).contains(Saml2ParameterNames.SAML_RESPONSE);
assertThat(content).contains(registration.getAssertingPartyDetails().getSingleLogoutServiceResponseLocation());
verify(this.securityContextHolderStrategy).getContext();
}
@Test