Add SecurityContextHolderStrategy Test Support

Issue gh-11061
Issue gh-11444
This commit is contained in:
Josh Cummings 2022-06-21 17:00:05 -06:00
parent fa0086d3b0
commit f86992a0af
No known key found for this signature in database
GPG Key ID: A306A51F43B8E5A5
6 changed files with 122 additions and 15 deletions

View File

@ -0,0 +1,45 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.test.context;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
public final class TestSecurityContextHolderStrategyAdapter implements SecurityContextHolderStrategy {
@Override
public void clearContext() {
TestSecurityContextHolder.clearContext();
}
@Override
public SecurityContext getContext() {
return TestSecurityContextHolder.getContext();
}
@Override
public void setContext(SecurityContext context) {
TestSecurityContextHolder.setContext(context);
}
@Override
public SecurityContext createEmptyContext() {
return SecurityContextHolder.createEmptyContext();
}
}

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2014 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -18,12 +18,14 @@ package org.springframework.security.test.context.support;
import java.util.List; import java.util.List;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
/** /**
* A {@link WithAnonymousUserSecurityContextFactory} that runs with an * A {@link WithAnonymousUserSecurityContextFactory} that runs with an
@ -35,13 +37,21 @@ import org.springframework.security.core.context.SecurityContextHolder;
*/ */
final class WithAnonymousUserSecurityContextFactory implements WithSecurityContextFactory<WithAnonymousUser> { final class WithAnonymousUserSecurityContextFactory implements WithSecurityContextFactory<WithAnonymousUser> {
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
@Override @Override
public SecurityContext createSecurityContext(WithAnonymousUser withUser) { public SecurityContext createSecurityContext(WithAnonymousUser withUser) {
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"); List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS");
Authentication authentication = new AnonymousAuthenticationToken("key", "anonymous", authorities); Authentication authentication = new AnonymousAuthenticationToken("key", "anonymous", authorities);
SecurityContext context = SecurityContextHolder.createEmptyContext(); SecurityContext context = this.securityContextHolderStrategy.createEmptyContext();
context.setAuthentication(authentication); context.setAuthentication(authentication);
return context; return context;
} }
@Autowired(required = false)
void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
} }

View File

@ -20,12 +20,14 @@ import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.User;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
@ -39,6 +41,9 @@ import org.springframework.util.StringUtils;
*/ */
final class WithMockUserSecurityContextFactory implements WithSecurityContextFactory<WithMockUser> { final class WithMockUserSecurityContextFactory implements WithSecurityContextFactory<WithMockUser> {
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
@Override @Override
public SecurityContext createSecurityContext(WithMockUser withUser) { public SecurityContext createSecurityContext(WithMockUser withUser) {
String username = StringUtils.hasLength(withUser.username()) ? withUser.username() : withUser.value(); String username = StringUtils.hasLength(withUser.username()) ? withUser.username() : withUser.value();
@ -60,9 +65,14 @@ final class WithMockUserSecurityContextFactory implements WithSecurityContextFac
User principal = new User(username, withUser.password(), true, true, true, true, grantedAuthorities); User principal = new User(username, withUser.password(), true, true, true, true, grantedAuthorities);
Authentication authentication = UsernamePasswordAuthenticationToken.authenticated(principal, Authentication authentication = UsernamePasswordAuthenticationToken.authenticated(principal,
principal.getPassword(), principal.getAuthorities()); principal.getPassword(), principal.getAuthorities());
SecurityContext context = SecurityContextHolder.createEmptyContext(); SecurityContext context = this.securityContextHolderStrategy.createEmptyContext();
context.setAuthentication(authentication); context.setAuthentication(authentication);
return context; return context;
} }
@Autowired(required = false)
void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2018 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -21,12 +21,16 @@ import java.lang.reflect.AnnotatedElement;
import java.util.function.Supplier; import java.util.function.Supplier;
import org.springframework.beans.BeanUtils; import org.springframework.beans.BeanUtils;
import org.springframework.context.ApplicationContext;
import org.springframework.core.GenericTypeResolver; import org.springframework.core.GenericTypeResolver;
import org.springframework.core.annotation.AnnotatedElementUtils; import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.core.annotation.AnnotationUtils; import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.test.context.TestSecurityContextHolder; import org.springframework.security.test.context.TestSecurityContextHolder;
import org.springframework.security.test.context.TestSecurityContextHolderStrategyAdapter;
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors; import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors;
import org.springframework.test.context.TestContext; import org.springframework.test.context.TestContext;
import org.springframework.test.context.TestContextAnnotationUtils; import org.springframework.test.context.TestContextAnnotationUtils;
@ -53,6 +57,19 @@ public class WithSecurityContextTestExecutionListener extends AbstractTestExecut
static final String SECURITY_CONTEXT_ATTR_NAME = WithSecurityContextTestExecutionListener.class.getName() static final String SECURITY_CONTEXT_ATTR_NAME = WithSecurityContextTestExecutionListener.class.getName()
.concat(".SECURITY_CONTEXT"); .concat(".SECURITY_CONTEXT");
static final SecurityContextHolderStrategy DEFAULT_SECURITY_CONTEXT_HOLDER_STRATEGY = new TestSecurityContextHolderStrategyAdapter();
Converter<TestContext, SecurityContextHolderStrategy> securityContextHolderStrategyConverter = (testContext) -> {
if (!testContext.hasApplicationContext()) {
return DEFAULT_SECURITY_CONTEXT_HOLDER_STRATEGY;
}
ApplicationContext context = testContext.getApplicationContext();
if (context.getBeanNamesForType(SecurityContextHolderStrategy.class).length == 0) {
return DEFAULT_SECURITY_CONTEXT_HOLDER_STRATEGY;
}
return context.getBean(SecurityContextHolderStrategy.class);
};
/** /**
* Sets up the {@link SecurityContext} for each test method. First the specific method * Sets up the {@link SecurityContext} for each test method. First the specific method
* is inspected for a {@link WithSecurityContext} or {@link Annotation} that has * is inspected for a {@link WithSecurityContext} or {@link Annotation} that has
@ -70,7 +87,7 @@ public class WithSecurityContextTestExecutionListener extends AbstractTestExecut
} }
Supplier<SecurityContext> supplier = testSecurityContext.getSecurityContextSupplier(); Supplier<SecurityContext> supplier = testSecurityContext.getSecurityContextSupplier();
if (testSecurityContext.getTestExecutionEvent() == TestExecutionEvent.TEST_METHOD) { if (testSecurityContext.getTestExecutionEvent() == TestExecutionEvent.TEST_METHOD) {
TestSecurityContextHolder.setContext(supplier.get()); this.securityContextHolderStrategyConverter.convert(testContext).setContext(supplier.get());
} }
else { else {
testContext.setAttribute(SECURITY_CONTEXT_ATTR_NAME, supplier); testContext.setAttribute(SECURITY_CONTEXT_ATTR_NAME, supplier);
@ -86,7 +103,7 @@ public class WithSecurityContextTestExecutionListener extends AbstractTestExecut
Supplier<SecurityContext> supplier = (Supplier<SecurityContext>) testContext Supplier<SecurityContext> supplier = (Supplier<SecurityContext>) testContext
.removeAttribute(SECURITY_CONTEXT_ATTR_NAME); .removeAttribute(SECURITY_CONTEXT_ATTR_NAME);
if (supplier != null) { if (supplier != null) {
TestSecurityContextHolder.setContext(supplier.get()); this.securityContextHolderStrategyConverter.convert(testContext).setContext(supplier.get());
} }
} }
@ -166,7 +183,7 @@ public class WithSecurityContextTestExecutionListener extends AbstractTestExecut
*/ */
@Override @Override
public void afterTestMethod(TestContext testContext) { public void afterTestMethod(TestContext testContext) {
TestSecurityContextHolder.clearContext(); this.securityContextHolderStrategyConverter.convert(testContext).clearContext();
} }
/** /**

View File

@ -24,6 +24,7 @@ import org.springframework.security.authentication.UsernamePasswordAuthenticatio
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.core.userdetails.ReactiveUserDetailsService; import org.springframework.security.core.userdetails.ReactiveUserDetailsService;
import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.core.userdetails.UserDetailsService;
@ -45,6 +46,9 @@ final class WithUserDetailsSecurityContextFactory implements WithSecurityContext
private static final boolean reactorPresent = ClassUtils.isPresent("reactor.core.publisher.Mono", private static final boolean reactorPresent = ClassUtils.isPresent("reactor.core.publisher.Mono",
WithUserDetailsSecurityContextFactory.class.getClassLoader()); WithUserDetailsSecurityContextFactory.class.getClassLoader());
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
private BeanFactory beans; private BeanFactory beans;
@Autowired @Autowired
@ -61,11 +65,16 @@ final class WithUserDetailsSecurityContextFactory implements WithSecurityContext
UserDetails principal = userDetailsService.loadUserByUsername(username); UserDetails principal = userDetailsService.loadUserByUsername(username);
Authentication authentication = UsernamePasswordAuthenticationToken.authenticated(principal, Authentication authentication = UsernamePasswordAuthenticationToken.authenticated(principal,
principal.getPassword(), principal.getAuthorities()); principal.getPassword(), principal.getAuthorities());
SecurityContext context = SecurityContextHolder.createEmptyContext(); SecurityContext context = this.securityContextHolderStrategy.createEmptyContext();
context.setAuthentication(authentication); context.setAuthentication(authentication);
return context; return context;
} }
@Autowired(required = false)
void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
private UserDetailsService findUserDetailsService(String beanName) { private UserDetailsService findUserDetailsService(String beanName) {
if (reactorPresent) { if (reactorPresent) {
UserDetailsService reactive = findAndAdaptReactiveUserDetailsService(beanName); UserDetailsService reactive = findAndAdaptReactiveUserDetailsService(beanName);

View File

@ -55,6 +55,7 @@ import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.User;
import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest; import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest;
@ -85,6 +86,7 @@ import org.springframework.security.oauth2.server.resource.authentication.JwtAut
import org.springframework.security.oauth2.server.resource.authentication.JwtGrantedAuthoritiesConverter; import org.springframework.security.oauth2.server.resource.authentication.JwtGrantedAuthoritiesConverter;
import org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionAuthenticatedPrincipal; import org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionAuthenticatedPrincipal;
import org.springframework.security.test.context.TestSecurityContextHolder; import org.springframework.security.test.context.TestSecurityContextHolder;
import org.springframework.security.test.context.TestSecurityContextHolderStrategyAdapter;
import org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers; import org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers;
import org.springframework.security.test.web.support.WebTestUtils; import org.springframework.security.test.web.support.WebTestUtils;
import org.springframework.security.web.context.HttpRequestResponseHolder; import org.springframework.security.web.context.HttpRequestResponseHolder;
@ -115,6 +117,8 @@ import org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandl
*/ */
public final class SecurityMockMvcRequestPostProcessors { public final class SecurityMockMvcRequestPostProcessors {
private static final SecurityContextHolderStrategy DEFAULT_SECURITY_CONTEXT_HOLDER_STRATEGY = new TestSecurityContextHolderStrategyAdapter();
private SecurityMockMvcRequestPostProcessors() { private SecurityMockMvcRequestPostProcessors() {
} }
@ -455,6 +459,18 @@ public final class SecurityMockMvcRequestPostProcessors {
return new OAuth2ClientRequestPostProcessor(registrationId); return new OAuth2ClientRequestPostProcessor(registrationId);
} }
private static SecurityContextHolderStrategy getSecurityContextHolderStrategy(HttpServletRequest request) {
WebApplicationContext context = WebApplicationContextUtils
.findWebApplicationContext(request.getServletContext());
if (context == null) {
return DEFAULT_SECURITY_CONTEXT_HOLDER_STRATEGY;
}
if (context.getBeanNamesForType(SecurityContextHolderStrategy.class).length == 0) {
return DEFAULT_SECURITY_CONTEXT_HOLDER_STRATEGY;
}
return context.getBean(SecurityContextHolderStrategy.class);
}
/** /**
* Populates the X509Certificate instances onto the request * Populates the X509Certificate instances onto the request
*/ */
@ -710,7 +726,7 @@ public final class SecurityMockMvcRequestPostProcessors {
* @param request the {@link HttpServletRequest} to use * @param request the {@link HttpServletRequest} to use
*/ */
final void save(Authentication authentication, HttpServletRequest request) { final void save(Authentication authentication, HttpServletRequest request) {
SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); SecurityContext securityContext = getSecurityContextHolderStrategy(request).createEmptyContext();
securityContext.setAuthentication(authentication); securityContext.setAuthentication(authentication);
save(securityContext, request); save(securityContext, request);
} }
@ -790,8 +806,6 @@ public final class SecurityMockMvcRequestPostProcessors {
private static final class TestSecurityContextHolderPostProcessor extends SecurityContextRequestPostProcessorSupport private static final class TestSecurityContextHolderPostProcessor extends SecurityContextRequestPostProcessorSupport
implements RequestPostProcessor { implements RequestPostProcessor {
private SecurityContext EMPTY = SecurityContextHolder.createEmptyContext();
@Override @Override
public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) {
// TestSecurityContextHolder is only a default value // TestSecurityContextHolder is only a default value
@ -799,8 +813,10 @@ public final class SecurityMockMvcRequestPostProcessors {
if (existingContext != null) { if (existingContext != null) {
return request; return request;
} }
SecurityContext context = TestSecurityContextHolder.getContext(); SecurityContextHolderStrategy strategy = getSecurityContextHolderStrategy(request);
if (!this.EMPTY.equals(context)) { SecurityContext empty = strategy.createEmptyContext();
SecurityContext context = strategy.getContext();
if (!empty.equals(context)) {
save(context, request); save(context, request);
} }
return request; return request;
@ -851,7 +867,7 @@ public final class SecurityMockMvcRequestPostProcessors {
@Override @Override
public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) {
SecurityContext context = SecurityContextHolder.createEmptyContext(); SecurityContext context = getSecurityContextHolderStrategy(request).createEmptyContext();
context.setAuthentication(this.authentication); context.setAuthentication(this.authentication);
save(this.authentication, request); save(this.authentication, request);
return request; return request;
@ -869,7 +885,7 @@ public final class SecurityMockMvcRequestPostProcessors {
*/ */
private static final class UserDetailsRequestPostProcessor implements RequestPostProcessor { private static final class UserDetailsRequestPostProcessor implements RequestPostProcessor {
private final RequestPostProcessor delegate; private final AuthenticationRequestPostProcessor delegate;
UserDetailsRequestPostProcessor(UserDetails user) { UserDetailsRequestPostProcessor(UserDetails user) {
Authentication token = UsernamePasswordAuthenticationToken.authenticated(user, user.getPassword(), Authentication token = UsernamePasswordAuthenticationToken.authenticated(user, user.getPassword(),