From 27de315e5e6b71a828ddc64668f6c09d563f96c4 Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Tue, 21 Jun 2022 16:45:10 -0600 Subject: [PATCH] Use SecurityContextHolderStrategy for Async Requests Issue gh-11060 Issue gh-11061 --- .../HttpSecurityConfiguration.java | 14 ++++- .../config/http/HttpConfigurationBuilder.java | 1 + .../HttpSecurityConfigurationTests.java | 26 ++++++++-- .../config/http/MiscHttpConfigTests.java | 24 +++++++++ ...ests-WithSecurityContextHolderStrategy.xml | 51 +++++++++++++++++++ ...yContextCallableProcessingInterceptor.java | 23 +++++++-- .../WebAsyncManagerIntegrationFilter.java | 24 +++++++-- 7 files changed, 152 insertions(+), 11 deletions(-) create mode 100644 config/src/test/resources/org/springframework/security/config/http/MiscHttpConfigTests-WithSecurityContextHolderStrategy.xml diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/HttpSecurityConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/HttpSecurityConfiguration.java index 468ba74bf5..8b12f9bea6 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/HttpSecurityConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/HttpSecurityConfiguration.java @@ -33,6 +33,8 @@ import org.springframework.security.config.annotation.authentication.configurati import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer; import org.springframework.security.config.annotation.web.configurers.DefaultLoginPageConfigurer; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.security.web.context.request.async.WebAsyncManagerIntegrationFilter; import static org.springframework.security.config.Customizer.withDefaults; @@ -58,6 +60,9 @@ class HttpSecurityConfiguration { private ApplicationContext context; + private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder + .getContextHolderStrategy(); + @Autowired void setObjectPostProcessor(ObjectPostProcessor objectPostProcessor) { this.objectPostProcessor = objectPostProcessor; @@ -77,6 +82,11 @@ class HttpSecurityConfiguration { this.context = context; } + @Autowired(required = false) + void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) { + this.securityContextHolderStrategy = securityContextHolderStrategy; + } + @Bean(HTTPSECURITY_BEAN_NAME) @Scope("prototype") HttpSecurity httpSecurity() throws Exception { @@ -86,10 +96,12 @@ class HttpSecurityConfiguration { this.objectPostProcessor, passwordEncoder); authenticationBuilder.parentAuthenticationManager(authenticationManager()); HttpSecurity http = new HttpSecurity(this.objectPostProcessor, authenticationBuilder, createSharedObjects()); + WebAsyncManagerIntegrationFilter webAsyncManagerIntegrationFilter = new WebAsyncManagerIntegrationFilter(); + webAsyncManagerIntegrationFilter.setSecurityContextHolderStrategy(this.securityContextHolderStrategy); // @formatter:off http .csrf(withDefaults()) - .addFilter(new WebAsyncManagerIntegrationFilter()) + .addFilter(webAsyncManagerIntegrationFilter) .exceptionHandling(withDefaults()) .headers(withDefaults()) .sessionManagement(withDefaults()) diff --git a/config/src/main/java/org/springframework/security/config/http/HttpConfigurationBuilder.java b/config/src/main/java/org/springframework/security/config/http/HttpConfigurationBuilder.java index e38cd5dd8b..940eba4cf5 100644 --- a/config/src/main/java/org/springframework/security/config/http/HttpConfigurationBuilder.java +++ b/config/src/main/java/org/springframework/security/config/http/HttpConfigurationBuilder.java @@ -588,6 +588,7 @@ class HttpConfigurationBuilder { boolean asyncSupported = ClassUtils.hasMethod(ServletRequest.class, "startAsync"); if (asyncSupported) { this.webAsyncManagerFilter = new RootBeanDefinition(WebAsyncManagerIntegrationFilter.class); + this.webAsyncManagerFilter.getPropertyValues().add("securityContextHolderStrategy", this.holderStrategyRef); } } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/HttpSecurityConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/HttpSecurityConfigurationTests.java index 030f8f2a7e..d04c1bcf2f 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/HttpSecurityConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/HttpSecurityConfigurationTests.java @@ -36,11 +36,13 @@ import org.springframework.core.io.support.SpringFactoriesLoader; import org.springframework.mock.web.MockHttpSession; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.config.annotation.SecurityContextChangedListenerConfig; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer; import org.springframework.security.config.test.SpringTestContext; import org.springframework.security.config.test.SpringTestContextExtension; -import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UserDetailsService; @@ -55,6 +57,8 @@ import org.springframework.web.bind.annotation.RestController; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.verify; import static org.springframework.security.config.Customizer.withDefaults; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; @@ -135,6 +139,22 @@ public class HttpSecurityConfigurationTests { // @formatter:on } + @Test + public void asyncDispatchWhenCustomSecurityContextHolderStrategyThenUses() throws Exception { + this.spring.register(DefaultWithFilterChainConfig.class, SecurityContextChangedListenerConfig.class, + NameController.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder requestWithBob = get("/name").with(user("Bob")); + MvcResult mvcResult = this.mockMvc.perform(requestWithBob) + .andExpect(request().asyncStarted()) + .andReturn(); + this.mockMvc.perform(asyncDispatch(mvcResult)) + .andExpect(status().isOk()) + .andExpect(content().string("Bob")); + // @formatter:on + verify(this.spring.getContext().getBean(SecurityContextHolderStrategy.class), atLeastOnce()).getContext(); + } + @Test public void getWhenDefaultFilterChainBeanThenAnonymousPermitted() throws Exception { this.spring.register(AuthorizeRequestsConfig.class, UserDetailsConfig.class, BaseController.class).autowire(); @@ -244,8 +264,8 @@ public class HttpSecurityConfigurationTests { static class NameController { @GetMapping("/name") - Callable name() { - return () -> SecurityContextHolder.getContext().getAuthentication().getName(); + Callable name(Authentication authentication) { + return () -> authentication.getName(); } } diff --git a/config/src/test/java/org/springframework/security/config/http/MiscHttpConfigTests.java b/config/src/test/java/org/springframework/security/config/http/MiscHttpConfigTests.java index d65b278539..f9dcba9582 100644 --- a/config/src/test/java/org/springframework/security/config/http/MiscHttpConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/http/MiscHttpConfigTests.java @@ -27,6 +27,7 @@ import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.concurrent.Callable; import java.util.stream.Collectors; import javax.security.auth.Subject; @@ -129,12 +130,15 @@ import static org.mockito.Mockito.verify; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.httpBasic; +import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.x509; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.asyncDispatch; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.delete; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.request; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; /** @@ -766,6 +770,21 @@ public class MiscHttpConfigTests { // @formatter:on } + @Test + public void asyncDispatchWhenCustomSecurityContextHolderStrategyThenUses() throws Exception { + this.spring.configLocations(xml("WithSecurityContextHolderStrategy")).autowire(); + // @formatter:off + MockHttpServletRequestBuilder requestWithBob = get("/name").with(user("Bob")); + MvcResult mvcResult = this.mvc.perform(requestWithBob) + .andExpect(request().asyncStarted()) + .andReturn(); + this.mvc.perform(asyncDispatch(mvcResult)) + .andExpect(status().isOk()) + .andExpect(content().string("Bob")); + // @formatter:on + verify(this.spring.getContext().getBean(SecurityContextHolderStrategy.class), atLeastOnce()).getContext(); + } + /** * SEC-1893 */ @@ -909,6 +928,11 @@ public class MiscHttpConfigTests { return authentication.getDetails().getClass().getName(); } + @GetMapping("/name") + Callable name(Authentication authentication) { + return () -> authentication.getName(); + } + } @RestController diff --git a/config/src/test/resources/org/springframework/security/config/http/MiscHttpConfigTests-WithSecurityContextHolderStrategy.xml b/config/src/test/resources/org/springframework/security/config/http/MiscHttpConfigTests-WithSecurityContextHolderStrategy.xml new file mode 100644 index 0000000000..b89d3b380b --- /dev/null +++ b/config/src/test/resources/org/springframework/security/config/http/MiscHttpConfigTests-WithSecurityContextHolderStrategy.xml @@ -0,0 +1,51 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/web/src/main/java/org/springframework/security/web/context/request/async/SecurityContextCallableProcessingInterceptor.java b/web/src/main/java/org/springframework/security/web/context/request/async/SecurityContextCallableProcessingInterceptor.java index bb54c2cc9d..c0e8993584 100644 --- a/web/src/main/java/org/springframework/security/web/context/request/async/SecurityContextCallableProcessingInterceptor.java +++ b/web/src/main/java/org/springframework/security/web/context/request/async/SecurityContextCallableProcessingInterceptor.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ import java.util.concurrent.Callable; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.util.Assert; import org.springframework.web.context.request.NativeWebRequest; import org.springframework.web.context.request.async.CallableProcessingInterceptor; @@ -44,6 +45,9 @@ public final class SecurityContextCallableProcessingInterceptor extends Callable private volatile SecurityContext securityContext; + private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder + .getContextHolderStrategy(); + /** * Create a new {@link SecurityContextCallableProcessingInterceptor} that uses the * {@link SecurityContext} from the {@link SecurityContextHolder} at the time @@ -68,18 +72,29 @@ public final class SecurityContextCallableProcessingInterceptor extends Callable @Override public void beforeConcurrentHandling(NativeWebRequest request, Callable task) { if (this.securityContext == null) { - setSecurityContext(SecurityContextHolder.getContext()); + setSecurityContext(this.securityContextHolderStrategy.getContext()); } } @Override public void preProcess(NativeWebRequest request, Callable task) { - SecurityContextHolder.setContext(this.securityContext); + this.securityContextHolderStrategy.setContext(this.securityContext); } @Override public void postProcess(NativeWebRequest request, Callable task, Object concurrentResult) { - SecurityContextHolder.clearContext(); + this.securityContextHolderStrategy.clearContext(); + } + + /** + * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use + * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}. + * + * @since 5.8 + */ + public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) { + Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null"); + this.securityContextHolderStrategy = securityContextHolderStrategy; } private void setSecurityContext(SecurityContext securityContext) { diff --git a/web/src/main/java/org/springframework/security/web/context/request/async/WebAsyncManagerIntegrationFilter.java b/web/src/main/java/org/springframework/security/web/context/request/async/WebAsyncManagerIntegrationFilter.java index dbe0c65f4c..d4fffa6f6b 100644 --- a/web/src/main/java/org/springframework/security/web/context/request/async/WebAsyncManagerIntegrationFilter.java +++ b/web/src/main/java/org/springframework/security/web/context/request/async/WebAsyncManagerIntegrationFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,6 +25,9 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; +import org.springframework.util.Assert; import org.springframework.web.context.request.async.WebAsyncManager; import org.springframework.web.context.request.async.WebAsyncUtils; import org.springframework.web.filter.OncePerRequestFilter; @@ -42,6 +45,9 @@ public final class WebAsyncManagerIntegrationFilter extends OncePerRequestFilter private static final Object CALLABLE_INTERCEPTOR_KEY = new Object(); + private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder + .getContextHolderStrategy(); + @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { @@ -49,10 +55,22 @@ public final class WebAsyncManagerIntegrationFilter extends OncePerRequestFilter SecurityContextCallableProcessingInterceptor securityProcessingInterceptor = (SecurityContextCallableProcessingInterceptor) asyncManager .getCallableInterceptor(CALLABLE_INTERCEPTOR_KEY); if (securityProcessingInterceptor == null) { - asyncManager.registerCallableInterceptor(CALLABLE_INTERCEPTOR_KEY, - new SecurityContextCallableProcessingInterceptor()); + SecurityContextCallableProcessingInterceptor interceptor = new SecurityContextCallableProcessingInterceptor(); + interceptor.setSecurityContextHolderStrategy(this.securityContextHolderStrategy); + asyncManager.registerCallableInterceptor(CALLABLE_INTERCEPTOR_KEY, interceptor); } filterChain.doFilter(request, response); } + /** + * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use + * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}. + * + * @since 5.8 + */ + public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) { + Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null"); + this.securityContextHolderStrategy = securityContextHolderStrategy; + } + }