diff --git a/web/src/main/java/org/springframework/security/web/authentication/AuthenticationFilter.java b/web/src/main/java/org/springframework/security/web/authentication/AuthenticationFilter.java index 839f648713..6e683368ee 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/AuthenticationFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/AuthenticationFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,6 +32,7 @@ import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.security.web.context.NullSecurityContextRepository; import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.util.matcher.AnyRequestMatcher; @@ -67,6 +68,9 @@ import org.springframework.web.filter.OncePerRequestFilter; */ public class AuthenticationFilter extends OncePerRequestFilter { + private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder + .getContextHolderStrategy(); + private RequestMatcher requestMatcher = AnyRequestMatcher.INSTANCE; private AuthenticationConverter authenticationConverter; @@ -151,6 +155,17 @@ public class AuthenticationFilter extends OncePerRequestFilter { this.securityContextRepository = securityContextRepository; } + /** + * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use + * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}. + * + * @since 5.8 + */ + public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) { + Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null"); + this.securityContextHolderStrategy = securityContextHolderStrategy; + } + @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { @@ -180,15 +195,15 @@ public class AuthenticationFilter extends OncePerRequestFilter { private void unsuccessfulAuthentication(HttpServletRequest request, HttpServletResponse response, AuthenticationException failed) throws IOException, ServletException { - SecurityContextHolder.clearContext(); + this.securityContextHolderStrategy.clearContext(); this.failureHandler.onAuthenticationFailure(request, response, failed); } private void successfulAuthentication(HttpServletRequest request, HttpServletResponse response, FilterChain chain, Authentication authentication) throws IOException, ServletException { - SecurityContext context = SecurityContextHolder.createEmptyContext(); + SecurityContext context = this.securityContextHolderStrategy.createEmptyContext(); context.setAuthentication(authentication); - SecurityContextHolder.setContext(context); + this.securityContextHolderStrategy.setContext(context); this.securityContextRepository.saveContext(context, request, response); this.successHandler.onAuthenticationSuccess(request, response, chain, authentication); } diff --git a/web/src/test/java/org/springframework/security/web/authentication/AuthenticationFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/AuthenticationFilterTests.java index 3e0a34955e..37eadafb38 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/AuthenticationFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/AuthenticationFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -40,6 +40,8 @@ import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; +import org.springframework.security.core.context.SecurityContextImpl; import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.util.matcher.RequestMatcher; @@ -128,6 +130,25 @@ public class AuthenticationFilterTests { assertThat(SecurityContextHolder.getContext().getAuthentication()).isNotNull(); } + @Test + public void filterWhenCustomSecurityContextHolderStrategyThenUses() throws Exception { + Authentication authentication = new TestingAuthenticationToken("test", "this", "ROLE"); + given(this.authenticationConverter.convert(any())).willReturn(authentication); + given(this.authenticationManager.authenticate(any())).willReturn(authentication); + AuthenticationFilter filter = new AuthenticationFilter(this.authenticationManager, + this.authenticationConverter); + SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class); + given(strategy.createEmptyContext()).willReturn(new SecurityContextImpl()); + filter.setSecurityContextHolderStrategy(strategy); + MockHttpServletRequest request = new MockHttpServletRequest("GET", "/"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain chain = mock(FilterChain.class); + filter.doFilter(request, response, chain); + verify(this.authenticationManager).authenticate(any(Authentication.class)); + verify(chain).doFilter(any(ServletRequest.class), any(ServletResponse.class)); + verify(strategy).setContext(any()); + } + @Test public void filterWhenAuthenticationManagerResolverDefaultsAndAuthenticationSuccessThenContinues() throws Exception {