Add SecurityContextRepository to all Authentication Filters

Closes gh-10949
This commit is contained in:
Rob Winch 2022-03-09 15:48:03 -06:00
commit 28c7a4be11
12 changed files with 258 additions and 0 deletions

View File

@ -38,6 +38,8 @@ import org.springframework.security.oauth2.server.resource.authentication.JwtAut
import org.springframework.security.web.AuthenticationEntryPoint;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
import org.springframework.security.web.context.NullSecurityContextRepository;
import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.util.Assert;
import org.springframework.web.filter.OncePerRequestFilter;
@ -75,6 +77,8 @@ public final class BearerTokenAuthenticationFilter extends OncePerRequestFilter
private AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource = new WebAuthenticationDetailsSource();
private SecurityContextRepository securityContextRepository = new NullSecurityContextRepository();
/**
* Construct a {@code BearerTokenAuthenticationFilter} using the provided parameter(s)
* @param authenticationManagerResolver
@ -131,6 +135,7 @@ public final class BearerTokenAuthenticationFilter extends OncePerRequestFilter
SecurityContext context = SecurityContextHolder.createEmptyContext();
context.setAuthentication(authenticationResult);
SecurityContextHolder.setContext(context);
this.securityContextRepository.saveContext(context, request, response);
if (this.logger.isDebugEnabled()) {
this.logger.debug(LogMessage.format("Set SecurityContextHolder to %s", authenticationResult));
}
@ -143,6 +148,18 @@ public final class BearerTokenAuthenticationFilter extends OncePerRequestFilter
}
}
/**
* Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on
* authentication success. The default action is not to save the
* {@link SecurityContext}.
* @param securityContextRepository the {@link SecurityContextRepository} to use.
* Cannot be null.
*/
public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) {
Assert.notNull(securityContextRepository, "securityContextRepository cannot be null");
this.securityContextRepository = securityContextRepository;
}
/**
* Set the {@link BearerTokenResolver} to use. Defaults to
* {@link DefaultBearerTokenResolver}.

View File

@ -36,18 +36,23 @@ import org.springframework.security.authentication.AuthenticationDetailsSource;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.AuthenticationManagerResolver;
import org.springframework.security.authentication.AuthenticationServiceException;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.server.resource.BearerTokenAuthenticationToken;
import org.springframework.security.oauth2.server.resource.BearerTokenError;
import org.springframework.security.oauth2.server.resource.BearerTokenErrorCodes;
import org.springframework.security.web.AuthenticationEntryPoint;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.security.web.context.SecurityContextRepository;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
@ -102,6 +107,26 @@ public class BearerTokenAuthenticationFilterTests {
assertThat(captor.getValue().getPrincipal()).isEqualTo("token");
}
@Test
public void doFilterWhenSecurityContextRepositoryThenSaves() throws ServletException, IOException {
SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class);
String token = "token";
given(this.bearerTokenResolver.resolve(this.request)).willReturn(token);
TestingAuthenticationToken expectedAuthentication = new TestingAuthenticationToken("test", "password");
given(this.authenticationManager.authenticate(any())).willReturn(expectedAuthentication);
BearerTokenAuthenticationFilter filter = addMocks(
new BearerTokenAuthenticationFilter(this.authenticationManager));
filter.setSecurityContextRepository(securityContextRepository);
filter.doFilter(this.request, this.response, this.filterChain);
ArgumentCaptor<BearerTokenAuthenticationToken> captor = ArgumentCaptor
.forClass(BearerTokenAuthenticationToken.class);
verify(this.authenticationManager).authenticate(captor.capture());
assertThat(captor.getValue().getPrincipal()).isEqualTo(token);
ArgumentCaptor<SecurityContext> contextArg = ArgumentCaptor.forClass(SecurityContext.class);
verify(securityContextRepository).saveContext(contextArg.capture(), eq(this.request), eq(this.response));
assertThat(contextArg.getValue().getAuthentication().getName()).isEqualTo(expectedAuthentication.getName());
}
@Test
public void doFilterWhenUsingAuthenticationManagerResolverThenAuthenticates() throws Exception {
BearerTokenAuthenticationFilter filter = addMocks(

View File

@ -32,6 +32,8 @@ import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.context.NullSecurityContextRepository;
import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.security.web.util.matcher.AnyRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert;
@ -74,6 +76,8 @@ public class AuthenticationFilter extends OncePerRequestFilter {
private AuthenticationFailureHandler failureHandler = new AuthenticationEntryPointFailureHandler(
new HttpStatusEntryPoint(HttpStatus.UNAUTHORIZED));
private SecurityContextRepository securityContextRepository = new NullSecurityContextRepository();
private AuthenticationManagerResolver<HttpServletRequest> authenticationManagerResolver;
public AuthenticationFilter(AuthenticationManager authenticationManager,
@ -135,6 +139,18 @@ public class AuthenticationFilter extends OncePerRequestFilter {
this.authenticationManagerResolver = authenticationManagerResolver;
}
/**
* Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on
* authentication success. The default action is not to save the
* {@link SecurityContext}.
* @param securityContextRepository the {@link SecurityContextRepository} to use.
* Cannot be null.
*/
public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) {
Assert.notNull(securityContextRepository, "securityContextRepository cannot be null");
this.securityContextRepository = securityContextRepository;
}
@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException {
@ -173,6 +189,7 @@ public class AuthenticationFilter extends OncePerRequestFilter {
SecurityContext context = SecurityContextHolder.createEmptyContext();
context.setAuthentication(authentication);
SecurityContextHolder.setContext(context);
this.securityContextRepository.saveContext(context, request, response);
this.successHandler.onAuthenticationSuccess(request, response, chain, authentication);
}

View File

@ -40,6 +40,8 @@ import org.springframework.security.web.WebAttributes;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
import org.springframework.security.web.context.NullSecurityContextRepository;
import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert;
import org.springframework.web.filter.GenericFilterBean;
@ -104,6 +106,8 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi
private RequestMatcher requiresAuthenticationRequestMatcher = new PreAuthenticatedProcessingRequestMatcher();
private SecurityContextRepository securityContextRepository = new NullSecurityContextRepository();
/**
* Check whether all required properties have been set.
*/
@ -210,6 +214,7 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi
SecurityContext context = SecurityContextHolder.createEmptyContext();
context.setAuthentication(authResult);
SecurityContextHolder.setContext(context);
this.securityContextRepository.saveContext(context, request, response);
if (this.eventPublisher != null) {
this.eventPublisher.publishEvent(new InteractiveAuthenticationSuccessEvent(authResult, this.getClass()));
}
@ -242,6 +247,18 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi
this.eventPublisher = anApplicationEventPublisher;
}
/**
* Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on
* authentication success. The default action is not to save the
* {@link SecurityContext}.
* @param securityContextRepository the {@link SecurityContextRepository} to use.
* Cannot be null.
*/
public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) {
Assert.notNull(securityContextRepository, "securityContextRepository cannot be null");
this.securityContextRepository = securityContextRepository;
}
/**
* @param authenticationDetailsSource The AuthenticationDetailsSource to use
*/

View File

@ -36,6 +36,8 @@ import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import org.springframework.security.web.authentication.RememberMeServices;
import org.springframework.security.web.context.NullSecurityContextRepository;
import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.util.Assert;
import org.springframework.web.filter.GenericFilterBean;
@ -73,6 +75,8 @@ public class RememberMeAuthenticationFilter extends GenericFilterBean implements
private RememberMeServices rememberMeServices;
private SecurityContextRepository securityContextRepository = new NullSecurityContextRepository();
public RememberMeAuthenticationFilter(AuthenticationManager authenticationManager,
RememberMeServices rememberMeServices) {
Assert.notNull(authenticationManager, "authenticationManager cannot be null");
@ -114,6 +118,7 @@ public class RememberMeAuthenticationFilter extends GenericFilterBean implements
onSuccessfulAuthentication(request, response, rememberMeAuth);
this.logger.debug(LogMessage.of(() -> "SecurityContextHolder populated with remember-me token: '"
+ SecurityContextHolder.getContext().getAuthentication() + "'"));
this.securityContextRepository.saveContext(context, request, response);
if (this.eventPublisher != null) {
this.eventPublisher.publishEvent(new InteractiveAuthenticationSuccessEvent(
SecurityContextHolder.getContext().getAuthentication(), this.getClass()));
@ -179,4 +184,16 @@ public class RememberMeAuthenticationFilter extends GenericFilterBean implements
this.successHandler = successHandler;
}
/**
* Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on
* authentication success. The default action is not to save the
* {@link SecurityContext}.
* @param securityContextRepository the {@link SecurityContextRepository} to use.
* Cannot be null.
*/
public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) {
Assert.notNull(securityContextRepository, "securityContextRepository cannot be null");
this.securityContextRepository = securityContextRepository;
}
}

View File

@ -36,6 +36,8 @@ import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.AuthenticationEntryPoint;
import org.springframework.security.web.authentication.NullRememberMeServices;
import org.springframework.security.web.authentication.RememberMeServices;
import org.springframework.security.web.context.NullSecurityContextRepository;
import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.util.Assert;
import org.springframework.web.filter.OncePerRequestFilter;
@ -103,6 +105,8 @@ public class BasicAuthenticationFilter extends OncePerRequestFilter {
private BasicAuthenticationConverter authenticationConverter = new BasicAuthenticationConverter();
private SecurityContextRepository securityContextRepository = new NullSecurityContextRepository();
/**
* Creates an instance which will authenticate against the supplied
* {@code AuthenticationManager} and which will ignore failed authentication attempts,
@ -131,6 +135,18 @@ public class BasicAuthenticationFilter extends OncePerRequestFilter {
this.authenticationEntryPoint = authenticationEntryPoint;
}
/**
* Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on
* authentication success. The default action is not to save the
* {@link SecurityContext}.
* @param securityContextRepository the {@link SecurityContextRepository} to use.
* Cannot be null.
*/
public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) {
Assert.notNull(securityContextRepository, "securityContextRepository cannot be null");
this.securityContextRepository = securityContextRepository;
}
@Override
public void afterPropertiesSet() {
Assert.notNull(this.authenticationManager, "An AuthenticationManager is required");
@ -161,6 +177,7 @@ public class BasicAuthenticationFilter extends OncePerRequestFilter {
this.logger.debug(LogMessage.format("Set SecurityContextHolder to %s", authResult));
}
this.rememberMeServices.loginSuccess(request, response, authResult);
this.securityContextRepository.saveContext(context, request, response);
onSuccessfulAuthentication(request, response, authResult);
}
}

View File

@ -49,6 +49,8 @@ import org.springframework.security.core.userdetails.UserDetailsService;
import org.springframework.security.core.userdetails.UsernameNotFoundException;
import org.springframework.security.core.userdetails.cache.NullUserCache;
import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
import org.springframework.security.web.context.NullSecurityContextRepository;
import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.filter.GenericFilterBean;
@ -106,6 +108,8 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes
private boolean createAuthenticatedToken = false;
private SecurityContextRepository securityContextRepository = new NullSecurityContextRepository();
@Override
public void afterPropertiesSet() {
Assert.notNull(this.userDetailsService, "A UserDetailsService is required");
@ -192,6 +196,7 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes
SecurityContext context = SecurityContextHolder.createEmptyContext();
context.setAuthentication(authentication);
SecurityContextHolder.setContext(context);
this.securityContextRepository.saveContext(context, request, response);
chain.doFilter(request, response);
}
@ -271,6 +276,18 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes
this.createAuthenticatedToken = createAuthenticatedToken;
}
/**
* Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on
* authentication success. The default action is not to save the
* {@link SecurityContext}.
* @param securityContextRepository the {@link SecurityContextRepository} to use.
* Cannot be null.
*/
public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) {
Assert.notNull(securityContextRepository, "securityContextRepository cannot be null");
this.securityContextRepository = securityContextRepository;
}
private class DigestData {
private final String username;

View File

@ -25,6 +25,7 @@ import jakarta.servlet.http.HttpServletRequest;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
@ -38,7 +39,9 @@ import org.springframework.security.authentication.AuthenticationManagerResolver
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.security.web.util.matcher.RequestMatcher;
import static org.assertj.core.api.Assertions.assertThat;
@ -256,4 +259,36 @@ public class AuthenticationFilterTests {
assertThat(session.getId()).isNotEqualTo(sessionId);
}
@Test
public void filterWhenSuccessfulAuthenticationThenNoSessionCreated() throws Exception {
Authentication authentication = new TestingAuthenticationToken("test", "this", "ROLE_USER");
given(this.authenticationConverter.convert(any())).willReturn(authentication);
given(this.authenticationManager.authenticate(any())).willReturn(authentication);
MockHttpServletRequest request = new MockHttpServletRequest("GET", "/");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain chain = new MockFilterChain();
AuthenticationFilter filter = new AuthenticationFilter(this.authenticationManager,
this.authenticationConverter);
filter.doFilter(request, response, chain);
assertThat(request.getSession(false)).isNull();
}
@Test
public void filterWhenCustomSecurityContextRepositoryAndSuccessfulAuthenticationRepositoryUsed() throws Exception {
SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class);
ArgumentCaptor<SecurityContext> securityContextArg = ArgumentCaptor.forClass(SecurityContext.class);
Authentication authentication = new TestingAuthenticationToken("test", "this", "ROLE_USER");
given(this.authenticationConverter.convert(any())).willReturn(authentication);
given(this.authenticationManager.authenticate(any())).willReturn(authentication);
MockHttpServletRequest request = new MockHttpServletRequest("GET", "/");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain chain = new MockFilterChain();
AuthenticationFilter filter = new AuthenticationFilter(this.authenticationManager,
this.authenticationConverter);
filter.setSecurityContextRepository(securityContextRepository);
filter.doFilter(request, response, chain);
verify(securityContextRepository).saveContext(securityContextArg.capture(), eq(request), eq(response));
assertThat(securityContextArg.getValue().getAuthentication()).isEqualTo(authentication);
}
}

View File

@ -23,6 +23,7 @@ import jakarta.servlet.http.HttpServletRequest;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
import org.springframework.mock.web.MockFilterChain;
@ -34,17 +35,20 @@ import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.userdetails.User;
import org.springframework.security.web.WebAttributes;
import org.springframework.security.web.authentication.ForwardAuthenticationFailureHandler;
import org.springframework.security.web.authentication.ForwardAuthenticationSuccessHandler;
import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
@ -210,6 +214,31 @@ public class AbstractPreAuthenticatedProcessingFilterTests {
assertThat(response.getForwardedUrl()).isEqualTo("/forwardUrl");
}
@Test
public void securityContextRepository() throws Exception {
SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class);
Object currentPrincipal = "currentUser";
TestingAuthenticationToken authRequest = new TestingAuthenticationToken(currentPrincipal, "something",
"ROLE_USER");
SecurityContextHolder.getContext().setAuthentication(authRequest);
MockHttpServletRequest request = new MockHttpServletRequest();
MockHttpServletResponse response = new MockHttpServletResponse();
MockFilterChain chain = new MockFilterChain();
ConcretePreAuthenticatedProcessingFilter filter = new ConcretePreAuthenticatedProcessingFilter();
filter.setSecurityContextRepository(securityContextRepository);
filter.setAuthenticationSuccessHandler(new ForwardAuthenticationSuccessHandler("/forwardUrl"));
filter.setCheckForPrincipalChanges(true);
filter.principal = "newUser";
AuthenticationManager am = mock(AuthenticationManager.class);
given(am.authenticate(any())).willReturn(authRequest);
filter.setAuthenticationManager(am);
filter.afterPropertiesSet();
filter.doFilter(request, response, chain);
ArgumentCaptor<SecurityContext> contextArg = ArgumentCaptor.forClass(SecurityContext.class);
verify(securityContextRepository).saveContext(contextArg.capture(), eq(request), eq(response));
assertThat(contextArg.getValue().getAuthentication().getPrincipal()).isEqualTo(authRequest.getName());
}
@Test
public void callsAuthenticationFailureHandlerOnFailedAuthentication() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest();

View File

@ -36,10 +36,12 @@ import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.authentication.NullRememberMeServices;
import org.springframework.security.web.authentication.RememberMeServices;
import org.springframework.security.web.authentication.SimpleUrlAuthenticationSuccessHandler;
import org.springframework.security.web.context.SecurityContextRepository;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
@ -152,6 +154,23 @@ public class RememberMeAuthenticationFilterTests {
verifyZeroInteractions(fc);
}
@Test
public void securityContextRepositoryInvokedIfSet() throws Exception {
SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class);
AuthenticationManager am = mock(AuthenticationManager.class);
given(am.authenticate(this.remembered)).willReturn(this.remembered);
RememberMeAuthenticationFilter filter = new RememberMeAuthenticationFilter(am,
new MockRememberMeServices(this.remembered));
filter.setAuthenticationSuccessHandler(new SimpleUrlAuthenticationSuccessHandler("/target"));
filter.setSecurityContextRepository(securityContextRepository);
MockHttpServletRequest request = new MockHttpServletRequest();
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain fc = mock(FilterChain.class);
request.setRequestURI("x");
filter.doFilter(request, response, fc);
verify(securityContextRepository).saveContext(any(), eq(request), eq(response));
}
private class MockRememberMeServices implements RememberMeServices {
private Authentication authToReturn;

View File

@ -27,6 +27,7 @@ import org.apache.commons.codec.binary.Base64;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
@ -36,8 +37,10 @@ import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.authentication.WebAuthenticationDetails;
import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.web.util.WebUtils;
import static org.assertj.core.api.Assertions.assertThat;
@ -364,4 +367,25 @@ public class BasicAuthenticationFilterTests {
assertThat(response.getStatus()).isEqualTo(401);
}
@Test
public void requestWhenSecurityContextRepository() throws Exception {
ArgumentCaptor<SecurityContext> contextArg = ArgumentCaptor.forClass(SecurityContext.class);
SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class);
this.filter.setSecurityContextRepository(securityContextRepository);
String token = "rod:koala";
MockHttpServletRequest request = new MockHttpServletRequest();
request.addHeader("Authorization", "Basic " + new String(Base64.encodeBase64(token.getBytes())));
request.setServletPath("/some_file.html");
MockHttpServletResponse response = new MockHttpServletResponse();
// Test
assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
FilterChain chain = mock(FilterChain.class);
this.filter.doFilter(request, response, chain);
verify(chain).doFilter(any(ServletRequest.class), any(ServletResponse.class));
assertThat(SecurityContextHolder.getContext().getAuthentication()).isNotNull();
assertThat(SecurityContextHolder.getContext().getAuthentication().getName()).isEqualTo("rod");
verify(securityContextRepository).saveContext(contextArg.capture(), eq(request), eq(response));
assertThat(contextArg.getValue().getAuthentication().getName()).isEqualTo("rod");
}
}

View File

@ -29,6 +29,7 @@ import org.apache.commons.codec.digest.DigestUtils;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
@ -40,10 +41,12 @@ import org.springframework.security.core.userdetails.User;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.core.userdetails.UserDetailsService;
import org.springframework.security.core.userdetails.cache.NullUserCache;
import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.util.StringUtils;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@ -389,4 +392,25 @@ public class DigestAuthenticationFilterTests {
assertThat(existingAuthentication).isSameAs(existingContext.getAuthentication());
}
@Test
public void testSecurityContextRepository() throws Exception {
SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class);
ArgumentCaptor<SecurityContext> contextArg = ArgumentCaptor.forClass(SecurityContext.class);
String responseDigest = DigestAuthUtils.generateDigest(false, USERNAME, REALM, PASSWORD, "GET", REQUEST_URI,
QOP, NONCE, NC, CNONCE);
this.request.addHeader("Authorization",
createAuthorizationHeader(USERNAME, REALM, NONCE, REQUEST_URI, responseDigest, QOP, NC, CNONCE));
this.filter.setSecurityContextRepository(securityContextRepository);
this.filter.setCreateAuthenticatedToken(true);
MockHttpServletResponse response = executeFilterInContainerSimulator(this.filter, this.request, true);
assertThat(SecurityContextHolder.getContext().getAuthentication()).isNotNull();
assertThat(((UserDetails) SecurityContextHolder.getContext().getAuthentication().getPrincipal()).getUsername())
.isEqualTo(USERNAME);
assertThat(SecurityContextHolder.getContext().getAuthentication().isAuthenticated()).isTrue();
assertThat(SecurityContextHolder.getContext().getAuthentication().getAuthorities())
.isEqualTo(AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO"));
verify(securityContextRepository).saveContext(contextArg.capture(), eq(this.request), eq(response));
assertThat(contextArg.getValue().getAuthentication().getName()).isEqualTo(USERNAME);
}
}