diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilter.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilter.java index b0d36b116a..199de3da95 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilter.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilter.java @@ -38,6 +38,8 @@ import org.springframework.security.oauth2.server.resource.authentication.JwtAut import org.springframework.security.web.AuthenticationEntryPoint; import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.WebAuthenticationDetailsSource; +import org.springframework.security.web.context.NullSecurityContextRepository; +import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.util.Assert; import org.springframework.web.filter.OncePerRequestFilter; @@ -75,6 +77,8 @@ public final class BearerTokenAuthenticationFilter extends OncePerRequestFilter private AuthenticationDetailsSource authenticationDetailsSource = new WebAuthenticationDetailsSource(); + private SecurityContextRepository securityContextRepository = new NullSecurityContextRepository(); + /** * Construct a {@code BearerTokenAuthenticationFilter} using the provided parameter(s) * @param authenticationManagerResolver @@ -131,6 +135,7 @@ public final class BearerTokenAuthenticationFilter extends OncePerRequestFilter SecurityContext context = SecurityContextHolder.createEmptyContext(); context.setAuthentication(authenticationResult); SecurityContextHolder.setContext(context); + this.securityContextRepository.saveContext(context, request, response); if (this.logger.isDebugEnabled()) { this.logger.debug(LogMessage.format("Set SecurityContextHolder to %s", authenticationResult)); } @@ -143,6 +148,18 @@ public final class BearerTokenAuthenticationFilter extends OncePerRequestFilter } } + /** + * Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on + * authentication success. The default action is not to save the + * {@link SecurityContext}. + * @param securityContextRepository the {@link SecurityContextRepository} to use. + * Cannot be null. + */ + public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) { + Assert.notNull(securityContextRepository, "securityContextRepository cannot be null"); + this.securityContextRepository = securityContextRepository; + } + /** * Set the {@link BearerTokenResolver} to use. Defaults to * {@link DefaultBearerTokenResolver}. diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilterTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilterTests.java index e1bc8a6550..d3fa9eadb0 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilterTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilterTests.java @@ -36,18 +36,23 @@ import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.AuthenticationManagerResolver; import org.springframework.security.authentication.AuthenticationServiceException; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.context.SecurityContext; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.server.resource.BearerTokenAuthenticationToken; import org.springframework.security.oauth2.server.resource.BearerTokenError; import org.springframework.security.oauth2.server.resource.BearerTokenErrorCodes; import org.springframework.security.web.AuthenticationEntryPoint; import org.springframework.security.web.authentication.AuthenticationFailureHandler; +import org.springframework.security.web.context.SecurityContextRepository; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -102,6 +107,26 @@ public class BearerTokenAuthenticationFilterTests { assertThat(captor.getValue().getPrincipal()).isEqualTo("token"); } + @Test + public void doFilterWhenSecurityContextRepositoryThenSaves() throws ServletException, IOException { + SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class); + String token = "token"; + given(this.bearerTokenResolver.resolve(this.request)).willReturn(token); + TestingAuthenticationToken expectedAuthentication = new TestingAuthenticationToken("test", "password"); + given(this.authenticationManager.authenticate(any())).willReturn(expectedAuthentication); + BearerTokenAuthenticationFilter filter = addMocks( + new BearerTokenAuthenticationFilter(this.authenticationManager)); + filter.setSecurityContextRepository(securityContextRepository); + filter.doFilter(this.request, this.response, this.filterChain); + ArgumentCaptor captor = ArgumentCaptor + .forClass(BearerTokenAuthenticationToken.class); + verify(this.authenticationManager).authenticate(captor.capture()); + assertThat(captor.getValue().getPrincipal()).isEqualTo(token); + ArgumentCaptor contextArg = ArgumentCaptor.forClass(SecurityContext.class); + verify(securityContextRepository).saveContext(contextArg.capture(), eq(this.request), eq(this.response)); + assertThat(contextArg.getValue().getAuthentication().getName()).isEqualTo(expectedAuthentication.getName()); + } + @Test public void doFilterWhenUsingAuthenticationManagerResolverThenAuthenticates() throws Exception { BearerTokenAuthenticationFilter filter = addMocks( diff --git a/web/src/main/java/org/springframework/security/web/authentication/AuthenticationFilter.java b/web/src/main/java/org/springframework/security/web/authentication/AuthenticationFilter.java index 43a7b04d79..839f648713 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/AuthenticationFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/AuthenticationFilter.java @@ -32,6 +32,8 @@ import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.web.context.NullSecurityContextRepository; +import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.util.matcher.AnyRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; @@ -74,6 +76,8 @@ public class AuthenticationFilter extends OncePerRequestFilter { private AuthenticationFailureHandler failureHandler = new AuthenticationEntryPointFailureHandler( new HttpStatusEntryPoint(HttpStatus.UNAUTHORIZED)); + private SecurityContextRepository securityContextRepository = new NullSecurityContextRepository(); + private AuthenticationManagerResolver authenticationManagerResolver; public AuthenticationFilter(AuthenticationManager authenticationManager, @@ -135,6 +139,18 @@ public class AuthenticationFilter extends OncePerRequestFilter { this.authenticationManagerResolver = authenticationManagerResolver; } + /** + * Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on + * authentication success. The default action is not to save the + * {@link SecurityContext}. + * @param securityContextRepository the {@link SecurityContextRepository} to use. + * Cannot be null. + */ + public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) { + Assert.notNull(securityContextRepository, "securityContextRepository cannot be null"); + this.securityContextRepository = securityContextRepository; + } + @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { @@ -173,6 +189,7 @@ public class AuthenticationFilter extends OncePerRequestFilter { SecurityContext context = SecurityContextHolder.createEmptyContext(); context.setAuthentication(authentication); SecurityContextHolder.setContext(context); + this.securityContextRepository.saveContext(context, request, response); this.successHandler.onAuthenticationSuccess(request, response, chain, authentication); } diff --git a/web/src/main/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilter.java b/web/src/main/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilter.java index ebc45501ef..9d4db0287b 100755 --- a/web/src/main/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilter.java @@ -40,6 +40,8 @@ import org.springframework.security.web.WebAttributes; import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.AuthenticationSuccessHandler; import org.springframework.security.web.authentication.WebAuthenticationDetailsSource; +import org.springframework.security.web.context.NullSecurityContextRepository; +import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; import org.springframework.web.filter.GenericFilterBean; @@ -104,6 +106,8 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi private RequestMatcher requiresAuthenticationRequestMatcher = new PreAuthenticatedProcessingRequestMatcher(); + private SecurityContextRepository securityContextRepository = new NullSecurityContextRepository(); + /** * Check whether all required properties have been set. */ @@ -210,6 +214,7 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi SecurityContext context = SecurityContextHolder.createEmptyContext(); context.setAuthentication(authResult); SecurityContextHolder.setContext(context); + this.securityContextRepository.saveContext(context, request, response); if (this.eventPublisher != null) { this.eventPublisher.publishEvent(new InteractiveAuthenticationSuccessEvent(authResult, this.getClass())); } @@ -242,6 +247,18 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi this.eventPublisher = anApplicationEventPublisher; } + /** + * Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on + * authentication success. The default action is not to save the + * {@link SecurityContext}. + * @param securityContextRepository the {@link SecurityContextRepository} to use. + * Cannot be null. + */ + public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) { + Assert.notNull(securityContextRepository, "securityContextRepository cannot be null"); + this.securityContextRepository = securityContextRepository; + } + /** * @param authenticationDetailsSource The AuthenticationDetailsSource to use */ diff --git a/web/src/main/java/org/springframework/security/web/authentication/rememberme/RememberMeAuthenticationFilter.java b/web/src/main/java/org/springframework/security/web/authentication/rememberme/RememberMeAuthenticationFilter.java index 000fef5e5c..aea8c80419 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/rememberme/RememberMeAuthenticationFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/rememberme/RememberMeAuthenticationFilter.java @@ -36,6 +36,8 @@ import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.web.authentication.AuthenticationSuccessHandler; import org.springframework.security.web.authentication.RememberMeServices; +import org.springframework.security.web.context.NullSecurityContextRepository; +import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.util.Assert; import org.springframework.web.filter.GenericFilterBean; @@ -73,6 +75,8 @@ public class RememberMeAuthenticationFilter extends GenericFilterBean implements private RememberMeServices rememberMeServices; + private SecurityContextRepository securityContextRepository = new NullSecurityContextRepository(); + public RememberMeAuthenticationFilter(AuthenticationManager authenticationManager, RememberMeServices rememberMeServices) { Assert.notNull(authenticationManager, "authenticationManager cannot be null"); @@ -114,6 +118,7 @@ public class RememberMeAuthenticationFilter extends GenericFilterBean implements onSuccessfulAuthentication(request, response, rememberMeAuth); this.logger.debug(LogMessage.of(() -> "SecurityContextHolder populated with remember-me token: '" + SecurityContextHolder.getContext().getAuthentication() + "'")); + this.securityContextRepository.saveContext(context, request, response); if (this.eventPublisher != null) { this.eventPublisher.publishEvent(new InteractiveAuthenticationSuccessEvent( SecurityContextHolder.getContext().getAuthentication(), this.getClass())); @@ -179,4 +184,16 @@ public class RememberMeAuthenticationFilter extends GenericFilterBean implements this.successHandler = successHandler; } + /** + * Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on + * authentication success. The default action is not to save the + * {@link SecurityContext}. + * @param securityContextRepository the {@link SecurityContextRepository} to use. + * Cannot be null. + */ + public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) { + Assert.notNull(securityContextRepository, "securityContextRepository cannot be null"); + this.securityContextRepository = securityContextRepository; + } + } diff --git a/web/src/main/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilter.java b/web/src/main/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilter.java index 82c0d617ff..c014455bfa 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilter.java @@ -36,6 +36,8 @@ import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.web.AuthenticationEntryPoint; import org.springframework.security.web.authentication.NullRememberMeServices; import org.springframework.security.web.authentication.RememberMeServices; +import org.springframework.security.web.context.NullSecurityContextRepository; +import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.util.Assert; import org.springframework.web.filter.OncePerRequestFilter; @@ -103,6 +105,8 @@ public class BasicAuthenticationFilter extends OncePerRequestFilter { private BasicAuthenticationConverter authenticationConverter = new BasicAuthenticationConverter(); + private SecurityContextRepository securityContextRepository = new NullSecurityContextRepository(); + /** * Creates an instance which will authenticate against the supplied * {@code AuthenticationManager} and which will ignore failed authentication attempts, @@ -131,6 +135,18 @@ public class BasicAuthenticationFilter extends OncePerRequestFilter { this.authenticationEntryPoint = authenticationEntryPoint; } + /** + * Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on + * authentication success. The default action is not to save the + * {@link SecurityContext}. + * @param securityContextRepository the {@link SecurityContextRepository} to use. + * Cannot be null. + */ + public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) { + Assert.notNull(securityContextRepository, "securityContextRepository cannot be null"); + this.securityContextRepository = securityContextRepository; + } + @Override public void afterPropertiesSet() { Assert.notNull(this.authenticationManager, "An AuthenticationManager is required"); @@ -161,6 +177,7 @@ public class BasicAuthenticationFilter extends OncePerRequestFilter { this.logger.debug(LogMessage.format("Set SecurityContextHolder to %s", authResult)); } this.rememberMeServices.loginSuccess(request, response, authResult); + this.securityContextRepository.saveContext(context, request, response); onSuccessfulAuthentication(request, response, authResult); } } diff --git a/web/src/main/java/org/springframework/security/web/authentication/www/DigestAuthenticationFilter.java b/web/src/main/java/org/springframework/security/web/authentication/www/DigestAuthenticationFilter.java index 70a25379fe..34cc3c7162 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/www/DigestAuthenticationFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/www/DigestAuthenticationFilter.java @@ -49,6 +49,8 @@ import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.core.userdetails.UsernameNotFoundException; import org.springframework.security.core.userdetails.cache.NullUserCache; import org.springframework.security.web.authentication.WebAuthenticationDetailsSource; +import org.springframework.security.web.context.NullSecurityContextRepository; +import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.web.filter.GenericFilterBean; @@ -106,6 +108,8 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes private boolean createAuthenticatedToken = false; + private SecurityContextRepository securityContextRepository = new NullSecurityContextRepository(); + @Override public void afterPropertiesSet() { Assert.notNull(this.userDetailsService, "A UserDetailsService is required"); @@ -192,6 +196,7 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes SecurityContext context = SecurityContextHolder.createEmptyContext(); context.setAuthentication(authentication); SecurityContextHolder.setContext(context); + this.securityContextRepository.saveContext(context, request, response); chain.doFilter(request, response); } @@ -271,6 +276,18 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes this.createAuthenticatedToken = createAuthenticatedToken; } + /** + * Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on + * authentication success. The default action is not to save the + * {@link SecurityContext}. + * @param securityContextRepository the {@link SecurityContextRepository} to use. + * Cannot be null. + */ + public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) { + Assert.notNull(securityContextRepository, "securityContextRepository cannot be null"); + this.securityContextRepository = securityContextRepository; + } + private class DigestData { private final String username; diff --git a/web/src/test/java/org/springframework/security/web/authentication/AuthenticationFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/AuthenticationFilterTests.java index cd737d9b7a..22e2095c7b 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/AuthenticationFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/AuthenticationFilterTests.java @@ -25,6 +25,7 @@ import jakarta.servlet.http.HttpServletRequest; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; @@ -38,7 +39,9 @@ import org.springframework.security.authentication.AuthenticationManagerResolver import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.util.matcher.RequestMatcher; import static org.assertj.core.api.Assertions.assertThat; @@ -256,4 +259,36 @@ public class AuthenticationFilterTests { assertThat(session.getId()).isNotEqualTo(sessionId); } + @Test + public void filterWhenSuccessfulAuthenticationThenNoSessionCreated() throws Exception { + Authentication authentication = new TestingAuthenticationToken("test", "this", "ROLE_USER"); + given(this.authenticationConverter.convert(any())).willReturn(authentication); + given(this.authenticationManager.authenticate(any())).willReturn(authentication); + MockHttpServletRequest request = new MockHttpServletRequest("GET", "/"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain chain = new MockFilterChain(); + AuthenticationFilter filter = new AuthenticationFilter(this.authenticationManager, + this.authenticationConverter); + filter.doFilter(request, response, chain); + assertThat(request.getSession(false)).isNull(); + } + + @Test + public void filterWhenCustomSecurityContextRepositoryAndSuccessfulAuthenticationRepositoryUsed() throws Exception { + SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class); + ArgumentCaptor securityContextArg = ArgumentCaptor.forClass(SecurityContext.class); + Authentication authentication = new TestingAuthenticationToken("test", "this", "ROLE_USER"); + given(this.authenticationConverter.convert(any())).willReturn(authentication); + given(this.authenticationManager.authenticate(any())).willReturn(authentication); + MockHttpServletRequest request = new MockHttpServletRequest("GET", "/"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain chain = new MockFilterChain(); + AuthenticationFilter filter = new AuthenticationFilter(this.authenticationManager, + this.authenticationConverter); + filter.setSecurityContextRepository(securityContextRepository); + filter.doFilter(request, response, chain); + verify(securityContextRepository).saveContext(securityContextArg.capture(), eq(request), eq(response)); + assertThat(securityContextArg.getValue().getAuthentication()).isEqualTo(authentication); + } + } diff --git a/web/src/test/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilterTests.java index 1c81209c75..e8dc80f882 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilterTests.java @@ -23,6 +23,7 @@ import jakarta.servlet.http.HttpServletRequest; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; import org.mockito.stubbing.Answer; import org.springframework.mock.web.MockFilterChain; @@ -34,17 +35,20 @@ import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.userdetails.User; import org.springframework.security.web.WebAttributes; import org.springframework.security.web.authentication.ForwardAuthenticationFailureHandler; import org.springframework.security.web.authentication.ForwardAuthenticationSuccessHandler; +import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -210,6 +214,31 @@ public class AbstractPreAuthenticatedProcessingFilterTests { assertThat(response.getForwardedUrl()).isEqualTo("/forwardUrl"); } + @Test + public void securityContextRepository() throws Exception { + SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class); + Object currentPrincipal = "currentUser"; + TestingAuthenticationToken authRequest = new TestingAuthenticationToken(currentPrincipal, "something", + "ROLE_USER"); + SecurityContextHolder.getContext().setAuthentication(authRequest); + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + MockFilterChain chain = new MockFilterChain(); + ConcretePreAuthenticatedProcessingFilter filter = new ConcretePreAuthenticatedProcessingFilter(); + filter.setSecurityContextRepository(securityContextRepository); + filter.setAuthenticationSuccessHandler(new ForwardAuthenticationSuccessHandler("/forwardUrl")); + filter.setCheckForPrincipalChanges(true); + filter.principal = "newUser"; + AuthenticationManager am = mock(AuthenticationManager.class); + given(am.authenticate(any())).willReturn(authRequest); + filter.setAuthenticationManager(am); + filter.afterPropertiesSet(); + filter.doFilter(request, response, chain); + ArgumentCaptor contextArg = ArgumentCaptor.forClass(SecurityContext.class); + verify(securityContextRepository).saveContext(contextArg.capture(), eq(request), eq(response)); + assertThat(contextArg.getValue().getAuthentication().getPrincipal()).isEqualTo(authRequest.getName()); + } + @Test public void callsAuthenticationFailureHandlerOnFailedAuthentication() throws Exception { MockHttpServletRequest request = new MockHttpServletRequest(); diff --git a/web/src/test/java/org/springframework/security/web/authentication/rememberme/RememberMeAuthenticationFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/rememberme/RememberMeAuthenticationFilterTests.java index 8d036b7388..a4ec54d3ed 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/rememberme/RememberMeAuthenticationFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/rememberme/RememberMeAuthenticationFilterTests.java @@ -36,10 +36,12 @@ import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.web.authentication.NullRememberMeServices; import org.springframework.security.web.authentication.RememberMeServices; import org.springframework.security.web.authentication.SimpleUrlAuthenticationSuccessHandler; +import org.springframework.security.web.context.SecurityContextRepository; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -152,6 +154,23 @@ public class RememberMeAuthenticationFilterTests { verifyZeroInteractions(fc); } + @Test + public void securityContextRepositoryInvokedIfSet() throws Exception { + SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class); + AuthenticationManager am = mock(AuthenticationManager.class); + given(am.authenticate(this.remembered)).willReturn(this.remembered); + RememberMeAuthenticationFilter filter = new RememberMeAuthenticationFilter(am, + new MockRememberMeServices(this.remembered)); + filter.setAuthenticationSuccessHandler(new SimpleUrlAuthenticationSuccessHandler("/target")); + filter.setSecurityContextRepository(securityContextRepository); + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain fc = mock(FilterChain.class); + request.setRequestURI("x"); + filter.doFilter(request, response, fc); + verify(securityContextRepository).saveContext(any(), eq(request), eq(response)); + } + private class MockRememberMeServices implements RememberMeServices { private Authentication authToReturn; diff --git a/web/src/test/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilterTests.java index 420f608636..a274e80f5b 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilterTests.java @@ -27,6 +27,7 @@ import org.apache.commons.codec.binary.Base64; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; @@ -36,8 +37,10 @@ import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.web.authentication.WebAuthenticationDetails; +import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.web.util.WebUtils; import static org.assertj.core.api.Assertions.assertThat; @@ -364,4 +367,25 @@ public class BasicAuthenticationFilterTests { assertThat(response.getStatus()).isEqualTo(401); } + @Test + public void requestWhenSecurityContextRepository() throws Exception { + ArgumentCaptor contextArg = ArgumentCaptor.forClass(SecurityContext.class); + SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class); + this.filter.setSecurityContextRepository(securityContextRepository); + String token = "rod:koala"; + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader("Authorization", "Basic " + new String(Base64.encodeBase64(token.getBytes()))); + request.setServletPath("/some_file.html"); + MockHttpServletResponse response = new MockHttpServletResponse(); + // Test + assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); + FilterChain chain = mock(FilterChain.class); + this.filter.doFilter(request, response, chain); + verify(chain).doFilter(any(ServletRequest.class), any(ServletResponse.class)); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isNotNull(); + assertThat(SecurityContextHolder.getContext().getAuthentication().getName()).isEqualTo("rod"); + verify(securityContextRepository).saveContext(contextArg.capture(), eq(request), eq(response)); + assertThat(contextArg.getValue().getAuthentication().getName()).isEqualTo("rod"); + } + } diff --git a/web/src/test/java/org/springframework/security/web/authentication/www/DigestAuthenticationFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/www/DigestAuthenticationFilterTests.java index f0f6e9a08f..badce48834 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/www/DigestAuthenticationFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/www/DigestAuthenticationFilterTests.java @@ -29,6 +29,7 @@ import org.apache.commons.codec.digest.DigestUtils; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; @@ -40,10 +41,12 @@ import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.core.userdetails.cache.NullUserCache; +import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.util.StringUtils; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -389,4 +392,25 @@ public class DigestAuthenticationFilterTests { assertThat(existingAuthentication).isSameAs(existingContext.getAuthentication()); } + @Test + public void testSecurityContextRepository() throws Exception { + SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class); + ArgumentCaptor contextArg = ArgumentCaptor.forClass(SecurityContext.class); + String responseDigest = DigestAuthUtils.generateDigest(false, USERNAME, REALM, PASSWORD, "GET", REQUEST_URI, + QOP, NONCE, NC, CNONCE); + this.request.addHeader("Authorization", + createAuthorizationHeader(USERNAME, REALM, NONCE, REQUEST_URI, responseDigest, QOP, NC, CNONCE)); + this.filter.setSecurityContextRepository(securityContextRepository); + this.filter.setCreateAuthenticatedToken(true); + MockHttpServletResponse response = executeFilterInContainerSimulator(this.filter, this.request, true); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isNotNull(); + assertThat(((UserDetails) SecurityContextHolder.getContext().getAuthentication().getPrincipal()).getUsername()) + .isEqualTo(USERNAME); + assertThat(SecurityContextHolder.getContext().getAuthentication().isAuthenticated()).isTrue(); + assertThat(SecurityContextHolder.getContext().getAuthentication().getAuthorities()) + .isEqualTo(AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO")); + verify(securityContextRepository).saveContext(contextArg.capture(), eq(this.request), eq(response)); + assertThat(contextArg.getValue().getAuthentication().getName()).isEqualTo(USERNAME); + } + }