Use SecurityContextHolderStrategy for Async Requests

Issue gh-11060
Issue gh-11061
This commit is contained in:
Josh Cummings 2022-06-21 16:45:10 -06:00
parent 5086409dcf
commit a218d3e140
No known key found for this signature in database
GPG Key ID: A306A51F43B8E5A5
7 changed files with 152 additions and 11 deletions

View File

@ -33,6 +33,8 @@ import org.springframework.security.config.annotation.authentication.configurati
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer;
import org.springframework.security.config.annotation.web.configurers.DefaultLoginPageConfigurer;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.web.context.request.async.WebAsyncManagerIntegrationFilter;
import static org.springframework.security.config.Customizer.withDefaults;
@ -58,6 +60,9 @@ class HttpSecurityConfiguration {
private ApplicationContext context;
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
@Autowired
void setObjectPostProcessor(ObjectPostProcessor<Object> objectPostProcessor) {
this.objectPostProcessor = objectPostProcessor;
@ -77,6 +82,11 @@ class HttpSecurityConfiguration {
this.context = context;
}
@Autowired(required = false)
void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
@Bean(HTTPSECURITY_BEAN_NAME)
@Scope("prototype")
HttpSecurity httpSecurity() throws Exception {
@ -86,10 +96,12 @@ class HttpSecurityConfiguration {
this.objectPostProcessor, passwordEncoder);
authenticationBuilder.parentAuthenticationManager(authenticationManager());
HttpSecurity http = new HttpSecurity(this.objectPostProcessor, authenticationBuilder, createSharedObjects());
WebAsyncManagerIntegrationFilter webAsyncManagerIntegrationFilter = new WebAsyncManagerIntegrationFilter();
webAsyncManagerIntegrationFilter.setSecurityContextHolderStrategy(this.securityContextHolderStrategy);
// @formatter:off
http
.csrf(withDefaults())
.addFilter(new WebAsyncManagerIntegrationFilter())
.addFilter(webAsyncManagerIntegrationFilter)
.exceptionHandling(withDefaults())
.headers(withDefaults())
.sessionManagement(withDefaults())

View File

@ -587,6 +587,7 @@ class HttpConfigurationBuilder {
boolean asyncSupported = ClassUtils.hasMethod(ServletRequest.class, "startAsync");
if (asyncSupported) {
this.webAsyncManagerFilter = new RootBeanDefinition(WebAsyncManagerIntegrationFilter.class);
this.webAsyncManagerFilter.getPropertyValues().add("securityContextHolderStrategy", this.holderStrategyRef);
}
}

View File

@ -35,11 +35,13 @@ import org.springframework.core.io.support.SpringFactoriesLoader;
import org.springframework.mock.web.MockHttpSession;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.config.annotation.SecurityContextChangedListenerConfig;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer;
import org.springframework.security.config.test.SpringTestContext;
import org.springframework.security.config.test.SpringTestContextExtension;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.core.userdetails.User;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.core.userdetails.UserDetailsService;
@ -54,6 +56,8 @@ import org.springframework.web.bind.annotation.RestController;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.verify;
import static org.springframework.security.config.Customizer.withDefaults;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf;
@ -134,6 +138,22 @@ public class HttpSecurityConfigurationTests {
// @formatter:on
}
@Test
public void asyncDispatchWhenCustomSecurityContextHolderStrategyThenUses() throws Exception {
this.spring.register(DefaultWithFilterChainConfig.class, SecurityContextChangedListenerConfig.class,
NameController.class).autowire();
// @formatter:off
MockHttpServletRequestBuilder requestWithBob = get("/name").with(user("Bob"));
MvcResult mvcResult = this.mockMvc.perform(requestWithBob)
.andExpect(request().asyncStarted())
.andReturn();
this.mockMvc.perform(asyncDispatch(mvcResult))
.andExpect(status().isOk())
.andExpect(content().string("Bob"));
// @formatter:on
verify(this.spring.getContext().getBean(SecurityContextHolderStrategy.class), atLeastOnce()).getContext();
}
@Test
public void getWhenDefaultFilterChainBeanThenAnonymousPermitted() throws Exception {
this.spring.register(AuthorizeRequestsConfig.class, UserDetailsConfig.class, BaseController.class).autowire();
@ -243,8 +263,8 @@ public class HttpSecurityConfigurationTests {
static class NameController {
@GetMapping("/name")
Callable<String> name() {
return () -> SecurityContextHolder.getContext().getAuthentication().getName();
Callable<String> name(Authentication authentication) {
return () -> authentication.getName();
}
}

View File

@ -27,6 +27,7 @@ import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.stream.Collectors;
import javax.security.auth.Subject;
@ -127,12 +128,15 @@ import static org.mockito.Mockito.verify;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.httpBasic;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.x509;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.asyncDispatch;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.delete;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.request;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
/**
@ -762,6 +766,21 @@ public class MiscHttpConfigTests {
// @formatter:on
}
@Test
public void asyncDispatchWhenCustomSecurityContextHolderStrategyThenUses() throws Exception {
this.spring.configLocations(xml("WithSecurityContextHolderStrategy")).autowire();
// @formatter:off
MockHttpServletRequestBuilder requestWithBob = get("/name").with(user("Bob"));
MvcResult mvcResult = this.mvc.perform(requestWithBob)
.andExpect(request().asyncStarted())
.andReturn();
this.mvc.perform(asyncDispatch(mvcResult))
.andExpect(status().isOk())
.andExpect(content().string("Bob"));
// @formatter:on
verify(this.spring.getContext().getBean(SecurityContextHolderStrategy.class), atLeastOnce()).getContext();
}
/**
* SEC-1893
*/
@ -905,6 +924,11 @@ public class MiscHttpConfigTests {
return authentication.getDetails().getClass().getName();
}
@GetMapping("/name")
Callable<String> name(Authentication authentication) {
return () -> authentication.getName();
}
}
@RestController

View File

@ -0,0 +1,51 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ Copyright 2002-2018 the original author or authors.
~
~ Licensed under the Apache License, Version 2.0 (the "License");
~ you may not use this file except in compliance with the License.
~ You may obtain a copy of the License at
~
~ https://www.apache.org/licenses/LICENSE-2.0
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS,
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
~ See the License for the specific language governing permissions and
~ limitations under the License.
-->
<b:beans xmlns:b="http://www.springframework.org/schema/beans"
xmlns:mvc="http://www.springframework.org/schema/mvc"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xmlns="http://www.springframework.org/schema/security"
xsi:schemaLocation="
http://www.springframework.org/schema/security
https://www.springframework.org/schema/security/spring-security.xsd
http://www.springframework.org/schema/beans
https://www.springframework.org/schema/beans/spring-beans.xsd
http://www.springframework.org/schema/mvc
https://www.springframework.org/schema/mvc/spring-mvc.xsd">
<http auto-config="true" security-context-holder-strategy-ref="ref">
<intercept-url pattern="/**" access="authenticated"/>
</http>
<b:bean id="ref" class="org.mockito.Mockito" factory-method="spy">
<b:constructor-arg>
<b:bean class="org.springframework.security.config.MockSecurityContextHolderStrategy"/>
</b:constructor-arg>
</b:bean>
<mvc:annotation-driven>
<mvc:argument-resolvers>
<b:bean class="org.springframework.security.web.method.annotation.AuthenticationPrincipalArgumentResolver">
<b:property name="securityContextHolderStrategy" ref="ref"/>
</b:bean>
</mvc:argument-resolvers>
</mvc:annotation-driven>
<b:bean class="org.springframework.security.config.http.MiscHttpConfigTests.AuthenticationController"/>
<b:import resource="userservice.xml"/>
</b:beans>

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -20,6 +20,7 @@ import java.util.concurrent.Callable;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.util.Assert;
import org.springframework.web.context.request.NativeWebRequest;
import org.springframework.web.context.request.async.CallableProcessingInterceptor;
@ -43,6 +44,9 @@ public final class SecurityContextCallableProcessingInterceptor implements Calla
private volatile SecurityContext securityContext;
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
/**
* Create a new {@link SecurityContextCallableProcessingInterceptor} that uses the
* {@link SecurityContext} from the {@link SecurityContextHolder} at the time
@ -67,18 +71,29 @@ public final class SecurityContextCallableProcessingInterceptor implements Calla
@Override
public <T> void beforeConcurrentHandling(NativeWebRequest request, Callable<T> task) {
if (this.securityContext == null) {
setSecurityContext(SecurityContextHolder.getContext());
setSecurityContext(this.securityContextHolderStrategy.getContext());
}
}
@Override
public <T> void preProcess(NativeWebRequest request, Callable<T> task) {
SecurityContextHolder.setContext(this.securityContext);
this.securityContextHolderStrategy.setContext(this.securityContext);
}
@Override
public <T> void postProcess(NativeWebRequest request, Callable<T> task, Object concurrentResult) {
SecurityContextHolder.clearContext();
this.securityContextHolderStrategy.clearContext();
}
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*
* @since 5.8
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
private void setSecurityContext(SecurityContext securityContext) {

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2016 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -25,6 +25,9 @@ import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.util.Assert;
import org.springframework.web.context.request.async.WebAsyncManager;
import org.springframework.web.context.request.async.WebAsyncUtils;
import org.springframework.web.filter.OncePerRequestFilter;
@ -42,6 +45,9 @@ public final class WebAsyncManagerIntegrationFilter extends OncePerRequestFilter
private static final Object CALLABLE_INTERCEPTOR_KEY = new Object();
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException {
@ -49,10 +55,22 @@ public final class WebAsyncManagerIntegrationFilter extends OncePerRequestFilter
SecurityContextCallableProcessingInterceptor securityProcessingInterceptor = (SecurityContextCallableProcessingInterceptor) asyncManager
.getCallableInterceptor(CALLABLE_INTERCEPTOR_KEY);
if (securityProcessingInterceptor == null) {
asyncManager.registerCallableInterceptor(CALLABLE_INTERCEPTOR_KEY,
new SecurityContextCallableProcessingInterceptor());
SecurityContextCallableProcessingInterceptor interceptor = new SecurityContextCallableProcessingInterceptor();
interceptor.setSecurityContextHolderStrategy(this.securityContextHolderStrategy);
asyncManager.registerCallableInterceptor(CALLABLE_INTERCEPTOR_KEY, interceptor);
}
filterChain.doFilter(request, response);
}
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*
* @since 5.8
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
}