diff --git a/web/src/main/java/org/springframework/security/web/authentication/www/DigestAuthenticationFilter.java b/web/src/main/java/org/springframework/security/web/authentication/www/DigestAuthenticationFilter.java index 21efe4f5f8..6bfb802e55 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/www/DigestAuthenticationFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/www/DigestAuthenticationFilter.java @@ -49,6 +49,8 @@ import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.core.userdetails.UsernameNotFoundException; import org.springframework.security.core.userdetails.cache.NullUserCache; import org.springframework.security.web.authentication.WebAuthenticationDetailsSource; +import org.springframework.security.web.context.NullSecurityContextRepository; +import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.web.filter.GenericFilterBean; @@ -106,6 +108,8 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes private boolean createAuthenticatedToken = false; + private SecurityContextRepository securityContextRepository = new NullSecurityContextRepository(); + @Override public void afterPropertiesSet() { Assert.notNull(this.userDetailsService, "A UserDetailsService is required"); @@ -192,6 +196,7 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes SecurityContext context = SecurityContextHolder.createEmptyContext(); context.setAuthentication(authentication); SecurityContextHolder.setContext(context); + this.securityContextRepository.saveContext(context, request, response); chain.doFilter(request, response); } @@ -271,6 +276,18 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes this.createAuthenticatedToken = createAuthenticatedToken; } + /** + * Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on + * authentication success. The default action is not to save the + * {@link SecurityContext}. + * @param securityContextRepository the {@link SecurityContextRepository} to use. + * Cannot be null. + */ + public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) { + Assert.notNull(securityContextRepository, "securityContextRepository cannot be null"); + this.securityContextRepository = securityContextRepository; + } + private class DigestData { private final String username; diff --git a/web/src/test/java/org/springframework/security/web/authentication/www/DigestAuthenticationFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/www/DigestAuthenticationFilterTests.java index 6168db7143..58211f2377 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/www/DigestAuthenticationFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/www/DigestAuthenticationFilterTests.java @@ -29,6 +29,7 @@ import org.apache.commons.codec.digest.DigestUtils; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; @@ -40,10 +41,12 @@ import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.core.userdetails.cache.NullUserCache; +import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.util.StringUtils; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -389,4 +392,25 @@ public class DigestAuthenticationFilterTests { assertThat(existingAuthentication).isSameAs(existingContext.getAuthentication()); } + @Test + public void testSecurityContextRepository() throws Exception { + SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class); + ArgumentCaptor contextArg = ArgumentCaptor.forClass(SecurityContext.class); + String responseDigest = DigestAuthUtils.generateDigest(false, USERNAME, REALM, PASSWORD, "GET", REQUEST_URI, + QOP, NONCE, NC, CNONCE); + this.request.addHeader("Authorization", + createAuthorizationHeader(USERNAME, REALM, NONCE, REQUEST_URI, responseDigest, QOP, NC, CNONCE)); + this.filter.setSecurityContextRepository(securityContextRepository); + this.filter.setCreateAuthenticatedToken(true); + MockHttpServletResponse response = executeFilterInContainerSimulator(this.filter, this.request, true); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isNotNull(); + assertThat(((UserDetails) SecurityContextHolder.getContext().getAuthentication().getPrincipal()).getUsername()) + .isEqualTo(USERNAME); + assertThat(SecurityContextHolder.getContext().getAuthentication().isAuthenticated()).isTrue(); + assertThat(SecurityContextHolder.getContext().getAuthentication().getAuthorities()) + .isEqualTo(AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO")); + verify(securityContextRepository).saveContext(contextArg.capture(), eq(this.request), eq(response)); + assertThat(contextArg.getValue().getAuthentication().getName()).isEqualTo(USERNAME); + } + }