From cbba7ea4de069638d4799e14c837ec1c992b0a87 Mon Sep 17 00:00:00 2001 From: Rob Winch Date: Fri, 18 Feb 2022 15:47:47 -0600 Subject: [PATCH] AbstractAuthenticationProcessingFilter.securityContextRepository Issue gh-10953 --- ...bstractAuthenticationProcessingFilter.java | 17 +++++++++ ...ctAuthenticationProcessingFilterTests.java | 35 +++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/web/src/main/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilter.java b/web/src/main/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilter.java index a8ba24df5e..e7abefa6fd 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilter.java @@ -42,6 +42,8 @@ import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.web.authentication.session.NullAuthenticatedSessionStrategy; import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy; +import org.springframework.security.web.context.NullSecurityContextRepository; +import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; @@ -134,6 +136,8 @@ public abstract class AbstractAuthenticationProcessingFilter extends GenericFilt private AuthenticationFailureHandler failureHandler = new SimpleUrlAuthenticationFailureHandler(); + private SecurityContextRepository securityContextRepository = new NullSecurityContextRepository(); + /** * @param defaultFilterProcessesUrl the default value for filterProcessesUrl. */ @@ -314,6 +318,7 @@ public abstract class AbstractAuthenticationProcessingFilter extends GenericFilt SecurityContext context = SecurityContextHolder.createEmptyContext(); context.setAuthentication(authResult); SecurityContextHolder.setContext(context); + this.securityContextRepository.saveContext(context, request, response); if (this.logger.isDebugEnabled()) { this.logger.debug(LogMessage.format("Set SecurityContextHolder to %s", authResult)); } @@ -435,6 +440,18 @@ public abstract class AbstractAuthenticationProcessingFilter extends GenericFilt this.failureHandler = failureHandler; } + /** + * Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on + * authentication success. The default action is not to save the + * {@link SecurityContext}. + * @param securityContextRepository the {@link SecurityContextRepository} to use. + * Cannot be null. + */ + public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) { + Assert.notNull(securityContextRepository, "securityContextRepository cannot be null"); + this.securityContextRepository = securityContextRepository; + } + protected AuthenticationSuccessHandler getSuccessHandler() { return this.successHandler; } diff --git a/web/src/test/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilterTests.java index f7cdabce5c..c8b5816381 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilterTests.java @@ -27,6 +27,7 @@ import org.apache.commons.logging.Log; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; import org.springframework.mock.web.MockFilterConfig; import org.springframework.mock.web.MockHttpServletRequest; @@ -34,14 +35,17 @@ import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.authentication.InternalAuthenticationServiceException; +import org.springframework.security.authentication.TestAuthentication; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.web.authentication.rememberme.AbstractRememberMeServicesTests; import org.springframework.security.web.authentication.rememberme.TokenBasedRememberMeServices; import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy; +import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.firewall.DefaultHttpFirewall; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; @@ -322,6 +326,37 @@ public class AbstractAuthenticationProcessingFilterTests { assertThat(SecurityContextHolder.getContext().getAuthentication()).isNotNull(); } + @Test + public void testSuccessfulAuthenticationThenDefaultDoesNotCreateSession() throws Exception { + Authentication authentication = TestAuthentication.authenticatedUser(); + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + MockFilterChain chain = new MockFilterChain(false); + MockAuthenticationFilter filter = new MockAuthenticationFilter(); + + filter.successfulAuthentication(request, response, chain, authentication); + + assertThat(request.getSession(false)).isNull(); + } + + @Test + public void testSuccessfulAuthenticationWhenCustomSecurityContextRepositoryThenAuthenticationSaved() + throws Exception { + ArgumentCaptor contextCaptor = ArgumentCaptor.forClass(SecurityContext.class); + SecurityContextRepository repository = mock(SecurityContextRepository.class); + Authentication authentication = TestAuthentication.authenticatedUser(); + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + MockFilterChain chain = new MockFilterChain(false); + MockAuthenticationFilter filter = new MockAuthenticationFilter(); + filter.setSecurityContextRepository(repository); + + filter.successfulAuthentication(request, response, chain, authentication); + + verify(repository).saveContext(contextCaptor.capture(), eq(request), eq(response)); + assertThat(contextCaptor.getValue().getAuthentication()).isEqualTo(authentication); + } + @Test public void testFailedAuthenticationInvokesFailureHandler() throws Exception { // Setup our HTTP request