diff --git a/web/src/main/java/org/springframework/security/web/authentication/AuthenticationFilter.java b/web/src/main/java/org/springframework/security/web/authentication/AuthenticationFilter.java index 43a7b04d79..839f648713 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/AuthenticationFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/AuthenticationFilter.java @@ -32,6 +32,8 @@ import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.web.context.NullSecurityContextRepository; +import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.util.matcher.AnyRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; @@ -74,6 +76,8 @@ public class AuthenticationFilter extends OncePerRequestFilter { private AuthenticationFailureHandler failureHandler = new AuthenticationEntryPointFailureHandler( new HttpStatusEntryPoint(HttpStatus.UNAUTHORIZED)); + private SecurityContextRepository securityContextRepository = new NullSecurityContextRepository(); + private AuthenticationManagerResolver authenticationManagerResolver; public AuthenticationFilter(AuthenticationManager authenticationManager, @@ -135,6 +139,18 @@ public class AuthenticationFilter extends OncePerRequestFilter { this.authenticationManagerResolver = authenticationManagerResolver; } + /** + * Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on + * authentication success. The default action is not to save the + * {@link SecurityContext}. + * @param securityContextRepository the {@link SecurityContextRepository} to use. + * Cannot be null. + */ + public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) { + Assert.notNull(securityContextRepository, "securityContextRepository cannot be null"); + this.securityContextRepository = securityContextRepository; + } + @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { @@ -173,6 +189,7 @@ public class AuthenticationFilter extends OncePerRequestFilter { SecurityContext context = SecurityContextHolder.createEmptyContext(); context.setAuthentication(authentication); SecurityContextHolder.setContext(context); + this.securityContextRepository.saveContext(context, request, response); this.successHandler.onAuthenticationSuccess(request, response, chain, authentication); } diff --git a/web/src/test/java/org/springframework/security/web/authentication/AuthenticationFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/AuthenticationFilterTests.java index cd737d9b7a..22e2095c7b 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/AuthenticationFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/AuthenticationFilterTests.java @@ -25,6 +25,7 @@ import jakarta.servlet.http.HttpServletRequest; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; @@ -38,7 +39,9 @@ import org.springframework.security.authentication.AuthenticationManagerResolver import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.util.matcher.RequestMatcher; import static org.assertj.core.api.Assertions.assertThat; @@ -256,4 +259,36 @@ public class AuthenticationFilterTests { assertThat(session.getId()).isNotEqualTo(sessionId); } + @Test + public void filterWhenSuccessfulAuthenticationThenNoSessionCreated() throws Exception { + Authentication authentication = new TestingAuthenticationToken("test", "this", "ROLE_USER"); + given(this.authenticationConverter.convert(any())).willReturn(authentication); + given(this.authenticationManager.authenticate(any())).willReturn(authentication); + MockHttpServletRequest request = new MockHttpServletRequest("GET", "/"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain chain = new MockFilterChain(); + AuthenticationFilter filter = new AuthenticationFilter(this.authenticationManager, + this.authenticationConverter); + filter.doFilter(request, response, chain); + assertThat(request.getSession(false)).isNull(); + } + + @Test + public void filterWhenCustomSecurityContextRepositoryAndSuccessfulAuthenticationRepositoryUsed() throws Exception { + SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class); + ArgumentCaptor securityContextArg = ArgumentCaptor.forClass(SecurityContext.class); + Authentication authentication = new TestingAuthenticationToken("test", "this", "ROLE_USER"); + given(this.authenticationConverter.convert(any())).willReturn(authentication); + given(this.authenticationManager.authenticate(any())).willReturn(authentication); + MockHttpServletRequest request = new MockHttpServletRequest("GET", "/"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain chain = new MockFilterChain(); + AuthenticationFilter filter = new AuthenticationFilter(this.authenticationManager, + this.authenticationConverter); + filter.setSecurityContextRepository(securityContextRepository); + filter.doFilter(request, response, chain); + verify(securityContextRepository).saveContext(securityContextArg.capture(), eq(request), eq(response)); + assertThat(securityContextArg.getValue().getAuthentication()).isEqualTo(authentication); + } + }