diff --git a/web/src/main/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilter.java b/web/src/main/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilter.java index 9d639adf15..208c94c54f 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilter.java @@ -36,6 +36,8 @@ import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.web.AuthenticationEntryPoint; import org.springframework.security.web.authentication.NullRememberMeServices; import org.springframework.security.web.authentication.RememberMeServices; +import org.springframework.security.web.context.NullSecurityContextRepository; +import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.util.Assert; import org.springframework.web.filter.OncePerRequestFilter; @@ -103,6 +105,8 @@ public class BasicAuthenticationFilter extends OncePerRequestFilter { private BasicAuthenticationConverter authenticationConverter = new BasicAuthenticationConverter(); + private SecurityContextRepository securityContextRepository = new NullSecurityContextRepository(); + /** * Creates an instance which will authenticate against the supplied * {@code AuthenticationManager} and which will ignore failed authentication attempts, @@ -131,6 +135,18 @@ public class BasicAuthenticationFilter extends OncePerRequestFilter { this.authenticationEntryPoint = authenticationEntryPoint; } + /** + * Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on + * authentication success. The default action is not to save the + * {@link SecurityContext}. + * @param securityContextRepository the {@link SecurityContextRepository} to use. + * Cannot be null. + */ + public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) { + Assert.notNull(securityContextRepository, "securityContextRepository cannot be null"); + this.securityContextRepository = securityContextRepository; + } + @Override public void afterPropertiesSet() { Assert.notNull(this.authenticationManager, "An AuthenticationManager is required"); @@ -161,6 +177,7 @@ public class BasicAuthenticationFilter extends OncePerRequestFilter { this.logger.debug(LogMessage.format("Set SecurityContextHolder to %s", authResult)); } this.rememberMeServices.loginSuccess(request, response, authResult); + this.securityContextRepository.saveContext(context, request, response); onSuccessfulAuthentication(request, response, authResult); } } diff --git a/web/src/test/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilterTests.java index 13b265b8e0..dbb78c98f6 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilterTests.java @@ -27,6 +27,7 @@ import org.apache.commons.codec.binary.Base64; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; @@ -36,8 +37,10 @@ import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.web.authentication.WebAuthenticationDetails; +import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.web.util.WebUtils; import static org.assertj.core.api.Assertions.assertThat; @@ -364,4 +367,25 @@ public class BasicAuthenticationFilterTests { assertThat(response.getStatus()).isEqualTo(401); } + @Test + public void requestWhenSecurityContextRepository() throws Exception { + ArgumentCaptor contextArg = ArgumentCaptor.forClass(SecurityContext.class); + SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class); + this.filter.setSecurityContextRepository(securityContextRepository); + String token = "rod:koala"; + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader("Authorization", "Basic " + new String(Base64.encodeBase64(token.getBytes()))); + request.setServletPath("/some_file.html"); + MockHttpServletResponse response = new MockHttpServletResponse(); + // Test + assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); + FilterChain chain = mock(FilterChain.class); + this.filter.doFilter(request, response, chain); + verify(chain).doFilter(any(ServletRequest.class), any(ServletResponse.class)); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isNotNull(); + assertThat(SecurityContextHolder.getContext().getAuthentication().getName()).isEqualTo("rod"); + verify(securityContextRepository).saveContext(contextArg.capture(), eq(request), eq(response)); + assertThat(contextArg.getValue().getAuthentication().getName()).isEqualTo("rod"); + } + }