From 4462b73fd9ea0963da81716246ce2b96c7f61760 Mon Sep 17 00:00:00 2001 From: Rob Winch Date: Fri, 25 Feb 2022 15:47:00 -0600 Subject: [PATCH] AbstractPreAuthenticatedProcessingFilter.securityContextRepository Issue gh-10953 --- ...tractPreAuthenticatedProcessingFilter.java | 17 +++++++++++ ...PreAuthenticatedProcessingFilterTests.java | 29 +++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/web/src/main/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilter.java b/web/src/main/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilter.java index ebc45501ef..9d4db0287b 100755 --- a/web/src/main/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilter.java @@ -40,6 +40,8 @@ import org.springframework.security.web.WebAttributes; import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.AuthenticationSuccessHandler; import org.springframework.security.web.authentication.WebAuthenticationDetailsSource; +import org.springframework.security.web.context.NullSecurityContextRepository; +import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; import org.springframework.web.filter.GenericFilterBean; @@ -104,6 +106,8 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi private RequestMatcher requiresAuthenticationRequestMatcher = new PreAuthenticatedProcessingRequestMatcher(); + private SecurityContextRepository securityContextRepository = new NullSecurityContextRepository(); + /** * Check whether all required properties have been set. */ @@ -210,6 +214,7 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi SecurityContext context = SecurityContextHolder.createEmptyContext(); context.setAuthentication(authResult); SecurityContextHolder.setContext(context); + this.securityContextRepository.saveContext(context, request, response); if (this.eventPublisher != null) { this.eventPublisher.publishEvent(new InteractiveAuthenticationSuccessEvent(authResult, this.getClass())); } @@ -242,6 +247,18 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi this.eventPublisher = anApplicationEventPublisher; } + /** + * Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on + * authentication success. The default action is not to save the + * {@link SecurityContext}. + * @param securityContextRepository the {@link SecurityContextRepository} to use. + * Cannot be null. + */ + public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) { + Assert.notNull(securityContextRepository, "securityContextRepository cannot be null"); + this.securityContextRepository = securityContextRepository; + } + /** * @param authenticationDetailsSource The AuthenticationDetailsSource to use */ diff --git a/web/src/test/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilterTests.java index 1c81209c75..e8dc80f882 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilterTests.java @@ -23,6 +23,7 @@ import jakarta.servlet.http.HttpServletRequest; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; import org.mockito.stubbing.Answer; import org.springframework.mock.web.MockFilterChain; @@ -34,17 +35,20 @@ import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.userdetails.User; import org.springframework.security.web.WebAttributes; import org.springframework.security.web.authentication.ForwardAuthenticationFailureHandler; import org.springframework.security.web.authentication.ForwardAuthenticationSuccessHandler; +import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -210,6 +214,31 @@ public class AbstractPreAuthenticatedProcessingFilterTests { assertThat(response.getForwardedUrl()).isEqualTo("/forwardUrl"); } + @Test + public void securityContextRepository() throws Exception { + SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class); + Object currentPrincipal = "currentUser"; + TestingAuthenticationToken authRequest = new TestingAuthenticationToken(currentPrincipal, "something", + "ROLE_USER"); + SecurityContextHolder.getContext().setAuthentication(authRequest); + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + MockFilterChain chain = new MockFilterChain(); + ConcretePreAuthenticatedProcessingFilter filter = new ConcretePreAuthenticatedProcessingFilter(); + filter.setSecurityContextRepository(securityContextRepository); + filter.setAuthenticationSuccessHandler(new ForwardAuthenticationSuccessHandler("/forwardUrl")); + filter.setCheckForPrincipalChanges(true); + filter.principal = "newUser"; + AuthenticationManager am = mock(AuthenticationManager.class); + given(am.authenticate(any())).willReturn(authRequest); + filter.setAuthenticationManager(am); + filter.afterPropertiesSet(); + filter.doFilter(request, response, chain); + ArgumentCaptor contextArg = ArgumentCaptor.forClass(SecurityContext.class); + verify(securityContextRepository).saveContext(contextArg.capture(), eq(request), eq(response)); + assertThat(contextArg.getValue().getAuthentication().getPrincipal()).isEqualTo(authRequest.getName()); + } + @Test public void callsAuthenticationFailureHandlerOnFailedAuthentication() throws Exception { MockHttpServletRequest request = new MockHttpServletRequest();