diff --git a/web/src/main/java/org/springframework/security/web/authentication/switchuser/SwitchUserFilter.java b/web/src/main/java/org/springframework/security/web/authentication/switchuser/SwitchUserFilter.java index 737aa6a9ea..25d20a696a 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/switchuser/SwitchUserFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/switchuser/SwitchUserFilter.java @@ -49,6 +49,7 @@ import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.SpringSecurityMessageSource; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UserDetailsChecker; import org.springframework.security.core.userdetails.UserDetailsService; @@ -114,6 +115,9 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv public static final String ROLE_PREVIOUS_ADMINISTRATOR = "ROLE_PREVIOUS_ADMINISTRATOR"; + private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder + .getContextHolderStrategy(); + private ApplicationEventPublisher eventPublisher; private AuthenticationDetailsSource authenticationDetailsSource = new WebAuthenticationDetailsSource(); @@ -175,9 +179,9 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv try { Authentication targetUser = attemptSwitchUser(request); // update the current context to the new target user - SecurityContext context = SecurityContextHolder.createEmptyContext(); + SecurityContext context = this.securityContextHolderStrategy.createEmptyContext(); context.setAuthentication(targetUser); - SecurityContextHolder.setContext(context); + this.securityContextHolderStrategy.setContext(context); this.logger.debug(LogMessage.format("Set SecurityContextHolder to %s", targetUser)); // redirect to target url this.successHandler.onAuthenticationSuccess(request, response, targetUser); @@ -192,9 +196,9 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv // get the original authentication object (if exists) Authentication originalUser = attemptExitUser(request); // update the current context back to the original user - SecurityContext context = SecurityContextHolder.createEmptyContext(); + SecurityContext context = this.securityContextHolderStrategy.createEmptyContext(); context.setAuthentication(originalUser); - SecurityContextHolder.setContext(context); + this.securityContextHolderStrategy.setContext(context); this.logger.debug(LogMessage.format("Set SecurityContextHolder to %s", originalUser)); // redirect to target url this.successHandler.onAuthenticationSuccess(request, response, originalUser); @@ -228,7 +232,7 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv // publish event if (this.eventPublisher != null) { this.eventPublisher.publishEvent(new AuthenticationSwitchUserEvent( - SecurityContextHolder.getContext().getAuthentication(), targetUser)); + this.securityContextHolderStrategy.getContext().getAuthentication(), targetUser)); } return targetUserRequest; } @@ -244,7 +248,7 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv protected Authentication attemptExitUser(HttpServletRequest request) throws AuthenticationCredentialsNotFoundException { // need to check to see if the current user has a SwitchUserGrantedAuthority - Authentication current = SecurityContextHolder.getContext().getAuthentication(); + Authentication current = this.securityContextHolderStrategy.getContext().getAuthentication(); if (current == null) { throw new AuthenticationCredentialsNotFoundException(this.messages .getMessage("SwitchUserFilter.noCurrentUser", "No current user associated with this request")); @@ -310,7 +314,7 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv return attemptExitUser(request); } catch (AuthenticationCredentialsNotFoundException ex) { - return SecurityContextHolder.getContext().getAuthentication(); + return this.securityContextHolderStrategy.getContext().getAuthentication(); } } @@ -510,6 +514,17 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv this.switchAuthorityRole = switchAuthorityRole; } + /** + * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use + * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}. + * + * @since 5.8 + */ + public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) { + Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null"); + this.securityContextHolderStrategy = securityContextHolderStrategy; + } + private static RequestMatcher createMatcher(String pattern) { return new AntPathRequestMatcher(pattern, "POST", true, new UrlPathHelper()); } diff --git a/web/src/test/java/org/springframework/security/web/authentication/switchuser/SwitchUserFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/switchuser/SwitchUserFilterTests.java index 8959f099ba..b55ca11677 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/switchuser/SwitchUserFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/switchuser/SwitchUserFilterTests.java @@ -36,7 +36,10 @@ import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.authority.SimpleGrantedAuthority; +import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; +import org.springframework.security.core.context.SecurityContextImpl; import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UserDetailsService; @@ -49,8 +52,10 @@ import org.springframework.security.web.util.matcher.AnyRequestMatcher; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; /** @@ -416,6 +421,21 @@ public class SwitchUserFilterTests { assertThat(AuthorityUtils.authorityListToSet(result.getAuthorities())).contains("ROLE_NEW"); } + @Test + public void doFilterWhenCustomSecurityContextRepositoryThenUses() { + SecurityContextHolderStrategy securityContextHolderStrategy = spy(new MockSecurityContextHolderStrategy( + UsernamePasswordAuthenticationToken.unauthenticated("dano", "hawaii50"))); + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addParameter(SwitchUserFilter.SPRING_SECURITY_SWITCH_USERNAME_KEY, "jacklord"); + SwitchUserFilter filter = new SwitchUserFilter(); + filter.setSecurityContextHolderStrategy(securityContextHolderStrategy); + filter.setUserDetailsService(new MockUserDetailsService()); + Authentication result = filter.attemptSwitchUser(request); + assertThat(result).isNotNull(); + assertThat(result.getName()).isEqualTo("jacklord"); + verify(securityContextHolderStrategy, atLeastOnce()).getContext(); + } + // SEC-1763 @Test public void nestedSwitchesAreNotAllowed() { @@ -512,4 +532,34 @@ public class SwitchUserFilterTests { } + static final class MockSecurityContextHolderStrategy implements SecurityContextHolderStrategy { + + private SecurityContext mock; + + private MockSecurityContextHolderStrategy(Authentication authentication) { + this.mock = new SecurityContextImpl(authentication); + } + + @Override + public void clearContext() { + this.mock = null; + } + + @Override + public SecurityContext getContext() { + return this.mock; + } + + @Override + public void setContext(SecurityContext context) { + this.mock = context; + } + + @Override + public SecurityContext createEmptyContext() { + return new SecurityContextImpl(); + } + + } + }