diff --git a/web/src/main/java/org/springframework/security/web/authentication/switchuser/SwitchUserFilter.java b/web/src/main/java/org/springframework/security/web/authentication/switchuser/SwitchUserFilter.java index 737aa6a9ea..310bf0c516 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/switchuser/SwitchUserFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/switchuser/SwitchUserFilter.java @@ -58,6 +58,8 @@ import org.springframework.security.web.authentication.AuthenticationSuccessHand import org.springframework.security.web.authentication.SimpleUrlAuthenticationFailureHandler; import org.springframework.security.web.authentication.SimpleUrlAuthenticationSuccessHandler; import org.springframework.security.web.authentication.WebAuthenticationDetailsSource; +import org.springframework.security.web.context.RequestAttributeSecurityContextRepository; +import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.util.UrlUtils; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; @@ -142,6 +144,8 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv private AuthenticationFailureHandler failureHandler; + private SecurityContextRepository securityContextRepository = new RequestAttributeSecurityContextRepository(); + @Override public void afterPropertiesSet() { Assert.notNull(this.userDetailsService, "userDetailsService must be specified"); @@ -179,6 +183,7 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv context.setAuthentication(targetUser); SecurityContextHolder.setContext(context); this.logger.debug(LogMessage.format("Set SecurityContextHolder to %s", targetUser)); + this.securityContextRepository.saveContext(context, request, response); // redirect to target url this.successHandler.onAuthenticationSuccess(request, response, targetUser); } @@ -196,6 +201,7 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv context.setAuthentication(originalUser); SecurityContextHolder.setContext(context); this.logger.debug(LogMessage.format("Set SecurityContextHolder to %s", originalUser)); + this.securityContextRepository.saveContext(context, request, response); // redirect to target url this.successHandler.onAuthenticationSuccess(request, response, originalUser); return; @@ -510,6 +516,19 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv this.switchAuthorityRole = switchAuthorityRole; } + /** + * Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on + * switch user success. The default is + * {@link RequestAttributeSecurityContextRepository}. + * @param securityContextRepository the {@link SecurityContextRepository} to use. + * Cannot be null. + * @since 5.7.7 + */ + public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) { + Assert.notNull(securityContextRepository, "securityContextRepository cannot be null"); + this.securityContextRepository = securityContextRepository; + } + private static RequestMatcher createMatcher(String pattern) { return new AntPathRequestMatcher(pattern, "POST", true, new UrlPathHelper()); } diff --git a/web/src/test/java/org/springframework/security/web/authentication/switchuser/SwitchUserFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/switchuser/SwitchUserFilterTests.java index 8959f099ba..786bbc3054 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/switchuser/SwitchUserFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/switchuser/SwitchUserFilterTests.java @@ -16,15 +16,18 @@ package org.springframework.security.web.authentication.switchuser; +import java.io.IOException; import java.util.ArrayList; import java.util.List; import javax.servlet.FilterChain; +import javax.servlet.ServletException; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.AccountExpiredException; @@ -44,11 +47,15 @@ import org.springframework.security.core.userdetails.UsernameNotFoundException; import org.springframework.security.util.FieldUtils; import org.springframework.security.web.DefaultRedirectStrategy; import org.springframework.security.web.authentication.SimpleUrlAuthenticationSuccessHandler; +import org.springframework.security.web.context.RequestAttributeSecurityContextRepository; +import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.util.matcher.AnyRequestMatcher; +import org.springframework.test.util.ReflectionTestUtils; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; @@ -483,6 +490,59 @@ public class SwitchUserFilterTests { filter.setSwitchFailureUrl("/foo"); } + @Test + void filterWhenDefaultSecurityContextRepositoryThenRequestAttributeRepository() { + SwitchUserFilter switchUserFilter = new SwitchUserFilter(); + assertThat(ReflectionTestUtils.getField(switchUserFilter, "securityContextRepository")) + .isInstanceOf(RequestAttributeSecurityContextRepository.class); + } + + @Test + void doFilterWhenSwitchUserThenSaveSecurityContext() throws ServletException, IOException { + SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class); + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + MockFilterChain filterChain = new MockFilterChain(); + request.setParameter(SwitchUserFilter.SPRING_SECURITY_SWITCH_USERNAME_KEY, "jacklord"); + request.setRequestURI("/login/impersonate"); + SwitchUserFilter filter = new SwitchUserFilter(); + filter.setSecurityContextRepository(securityContextRepository); + filter.setUserDetailsService(new MockUserDetailsService()); + filter.setTargetUrl("/target"); + filter.afterPropertiesSet(); + + filter.doFilter(request, response, filterChain); + + verify(securityContextRepository).saveContext(any(), any(), any()); + } + + @Test + void doFilterWhenExitUserThenSaveSecurityContext() throws ServletException, IOException { + UsernamePasswordAuthenticationToken source = UsernamePasswordAuthenticationToken.authenticated("dano", + "hawaii50", ROLES_12); + // set current user (Admin) + List adminAuths = new ArrayList<>(ROLES_12); + adminAuths.add(new SwitchUserGrantedAuthority("PREVIOUS_ADMINISTRATOR", source)); + UsernamePasswordAuthenticationToken admin = UsernamePasswordAuthenticationToken.authenticated("jacklord", + "hawaii50", adminAuths); + SecurityContextHolder.getContext().setAuthentication(admin); + SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class); + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + MockFilterChain filterChain = new MockFilterChain(); + request.setParameter(SwitchUserFilter.SPRING_SECURITY_SWITCH_USERNAME_KEY, "jacklord"); + request.setRequestURI("/logout/impersonate"); + SwitchUserFilter filter = new SwitchUserFilter(); + filter.setSecurityContextRepository(securityContextRepository); + filter.setUserDetailsService(new MockUserDetailsService()); + filter.setTargetUrl("/target"); + filter.afterPropertiesSet(); + + filter.doFilter(request, response, filterChain); + + verify(securityContextRepository).saveContext(any(), any(), any()); + } + private class MockUserDetailsService implements UserDetailsService { private String password = "hawaii50";