diff --git a/web/src/main/java/org/springframework/security/web/authentication/rememberme/RememberMeAuthenticationFilter.java b/web/src/main/java/org/springframework/security/web/authentication/rememberme/RememberMeAuthenticationFilter.java index d8573c0d70..9e4a29999e 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/rememberme/RememberMeAuthenticationFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/rememberme/RememberMeAuthenticationFilter.java @@ -36,6 +36,8 @@ import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.web.authentication.AuthenticationSuccessHandler; import org.springframework.security.web.authentication.RememberMeServices; +import org.springframework.security.web.context.NullSecurityContextRepository; +import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.util.Assert; import org.springframework.web.filter.GenericFilterBean; @@ -73,6 +75,8 @@ public class RememberMeAuthenticationFilter extends GenericFilterBean implements private RememberMeServices rememberMeServices; + private SecurityContextRepository securityContextRepository = new NullSecurityContextRepository(); + public RememberMeAuthenticationFilter(AuthenticationManager authenticationManager, RememberMeServices rememberMeServices) { Assert.notNull(authenticationManager, "authenticationManager cannot be null"); @@ -114,6 +118,7 @@ public class RememberMeAuthenticationFilter extends GenericFilterBean implements onSuccessfulAuthentication(request, response, rememberMeAuth); this.logger.debug(LogMessage.of(() -> "SecurityContextHolder populated with remember-me token: '" + SecurityContextHolder.getContext().getAuthentication() + "'")); + this.securityContextRepository.saveContext(context, request, response); if (this.eventPublisher != null) { this.eventPublisher.publishEvent(new InteractiveAuthenticationSuccessEvent( SecurityContextHolder.getContext().getAuthentication(), this.getClass())); @@ -179,4 +184,16 @@ public class RememberMeAuthenticationFilter extends GenericFilterBean implements this.successHandler = successHandler; } + /** + * Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on + * authentication success. The default action is not to save the + * {@link SecurityContext}. + * @param securityContextRepository the {@link SecurityContextRepository} to use. + * Cannot be null. + */ + public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) { + Assert.notNull(securityContextRepository, "securityContextRepository cannot be null"); + this.securityContextRepository = securityContextRepository; + } + } diff --git a/web/src/test/java/org/springframework/security/web/authentication/rememberme/RememberMeAuthenticationFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/rememberme/RememberMeAuthenticationFilterTests.java index de778a94fd..44839268bd 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/rememberme/RememberMeAuthenticationFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/rememberme/RememberMeAuthenticationFilterTests.java @@ -36,10 +36,12 @@ import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.web.authentication.NullRememberMeServices; import org.springframework.security.web.authentication.RememberMeServices; import org.springframework.security.web.authentication.SimpleUrlAuthenticationSuccessHandler; +import org.springframework.security.web.context.SecurityContextRepository; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -152,6 +154,23 @@ public class RememberMeAuthenticationFilterTests { verifyZeroInteractions(fc); } + @Test + public void securityContextRepositoryInvokedIfSet() throws Exception { + SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class); + AuthenticationManager am = mock(AuthenticationManager.class); + given(am.authenticate(this.remembered)).willReturn(this.remembered); + RememberMeAuthenticationFilter filter = new RememberMeAuthenticationFilter(am, + new MockRememberMeServices(this.remembered)); + filter.setAuthenticationSuccessHandler(new SimpleUrlAuthenticationSuccessHandler("/target")); + filter.setSecurityContextRepository(securityContextRepository); + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain fc = mock(FilterChain.class); + request.setRequestURI("x"); + filter.doFilter(request, response, fc); + verify(securityContextRepository).saveContext(any(), eq(request), eq(response)); + } + private class MockRememberMeServices implements RememberMeServices { private Authentication authToReturn;