Use SecurityContextHolderStrategy for AuthenticationFilter

Issue gh-11060
This commit is contained in:
Josh Cummings 2022-06-21 16:47:05 -06:00
parent 74167d62b1
commit 0fee05d023
No known key found for this signature in database
GPG Key ID: A306A51F43B8E5A5
2 changed files with 41 additions and 5 deletions

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2019 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -32,6 +32,7 @@ import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.web.context.NullSecurityContextRepository; import org.springframework.security.web.context.NullSecurityContextRepository;
import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.security.web.util.matcher.AnyRequestMatcher; import org.springframework.security.web.util.matcher.AnyRequestMatcher;
@ -67,6 +68,9 @@ import org.springframework.web.filter.OncePerRequestFilter;
*/ */
public class AuthenticationFilter extends OncePerRequestFilter { public class AuthenticationFilter extends OncePerRequestFilter {
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
private RequestMatcher requestMatcher = AnyRequestMatcher.INSTANCE; private RequestMatcher requestMatcher = AnyRequestMatcher.INSTANCE;
private AuthenticationConverter authenticationConverter; private AuthenticationConverter authenticationConverter;
@ -151,6 +155,17 @@ public class AuthenticationFilter extends OncePerRequestFilter {
this.securityContextRepository = securityContextRepository; this.securityContextRepository = securityContextRepository;
} }
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*
* @since 5.8
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
@Override @Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException { throws ServletException, IOException {
@ -180,15 +195,15 @@ public class AuthenticationFilter extends OncePerRequestFilter {
private void unsuccessfulAuthentication(HttpServletRequest request, HttpServletResponse response, private void unsuccessfulAuthentication(HttpServletRequest request, HttpServletResponse response,
AuthenticationException failed) throws IOException, ServletException { AuthenticationException failed) throws IOException, ServletException {
SecurityContextHolder.clearContext(); this.securityContextHolderStrategy.clearContext();
this.failureHandler.onAuthenticationFailure(request, response, failed); this.failureHandler.onAuthenticationFailure(request, response, failed);
} }
private void successfulAuthentication(HttpServletRequest request, HttpServletResponse response, FilterChain chain, private void successfulAuthentication(HttpServletRequest request, HttpServletResponse response, FilterChain chain,
Authentication authentication) throws IOException, ServletException { Authentication authentication) throws IOException, ServletException {
SecurityContext context = SecurityContextHolder.createEmptyContext(); SecurityContext context = this.securityContextHolderStrategy.createEmptyContext();
context.setAuthentication(authentication); context.setAuthentication(authentication);
SecurityContextHolder.setContext(context); this.securityContextHolderStrategy.setContext(context);
this.securityContextRepository.saveContext(context, request, response); this.securityContextRepository.saveContext(context, request, response);
this.successHandler.onAuthenticationSuccess(request, response, chain, authentication); this.successHandler.onAuthenticationSuccess(request, response, chain, authentication);
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2019 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -41,6 +41,8 @@ import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.core.context.SecurityContextImpl;
import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher;
@ -129,6 +131,25 @@ public class AuthenticationFilterTests {
assertThat(SecurityContextHolder.getContext().getAuthentication()).isNotNull(); assertThat(SecurityContextHolder.getContext().getAuthentication()).isNotNull();
} }
@Test
public void filterWhenCustomSecurityContextHolderStrategyThenUses() throws Exception {
Authentication authentication = new TestingAuthenticationToken("test", "this", "ROLE");
given(this.authenticationConverter.convert(any())).willReturn(authentication);
given(this.authenticationManager.authenticate(any())).willReturn(authentication);
AuthenticationFilter filter = new AuthenticationFilter(this.authenticationManager,
this.authenticationConverter);
SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class);
given(strategy.createEmptyContext()).willReturn(new SecurityContextImpl());
filter.setSecurityContextHolderStrategy(strategy);
MockHttpServletRequest request = new MockHttpServletRequest("GET", "/");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain chain = mock(FilterChain.class);
filter.doFilter(request, response, chain);
verify(this.authenticationManager).authenticate(any(Authentication.class));
verify(chain).doFilter(any(ServletRequest.class), any(ServletResponse.class));
verify(strategy).setContext(any());
}
@Test @Test
public void filterWhenAuthenticationManagerResolverDefaultsAndAuthenticationSuccessThenContinues() public void filterWhenAuthenticationManagerResolverDefaultsAndAuthenticationSuccessThenContinues()
throws Exception { throws Exception {