Add SecurityContextHolderStrategy to Default Components

Issue gh-11060
This commit is contained in:
Josh Cummings 2022-05-26 14:14:10 -06:00
parent 01513ab17e
commit 31e25b115e
23 changed files with 362 additions and 52 deletions

View File

@ -47,6 +47,7 @@ import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.SpringSecurityMessageSource; import org.springframework.security.core.SpringSecurityMessageSource;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
@ -111,6 +112,9 @@ public abstract class AbstractSecurityInterceptor
protected MessageSourceAccessor messages = SpringSecurityMessageSource.getAccessor(); protected MessageSourceAccessor messages = SpringSecurityMessageSource.getAccessor();
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
private ApplicationEventPublisher eventPublisher; private ApplicationEventPublisher eventPublisher;
private AccessDecisionManager accessDecisionManager; private AccessDecisionManager accessDecisionManager;
@ -196,7 +200,7 @@ public abstract class AbstractSecurityInterceptor
publishEvent(new PublicInvocationEvent(object)); publishEvent(new PublicInvocationEvent(object));
return null; // no further work post-invocation return null; // no further work post-invocation
} }
if (SecurityContextHolder.getContext().getAuthentication() == null) { if (this.securityContextHolderStrategy.getContext().getAuthentication() == null) {
credentialsNotFound(this.messages.getMessage("AbstractSecurityInterceptor.authenticationNotFound", credentialsNotFound(this.messages.getMessage("AbstractSecurityInterceptor.authenticationNotFound",
"An Authentication object was not found in the SecurityContext"), object, attributes); "An Authentication object was not found in the SecurityContext"), object, attributes);
} }
@ -216,10 +220,10 @@ public abstract class AbstractSecurityInterceptor
// Attempt to run as a different user // Attempt to run as a different user
Authentication runAs = this.runAsManager.buildRunAs(authenticated, object, attributes); Authentication runAs = this.runAsManager.buildRunAs(authenticated, object, attributes);
if (runAs != null) { if (runAs != null) {
SecurityContext origCtx = SecurityContextHolder.getContext(); SecurityContext origCtx = this.securityContextHolderStrategy.getContext();
SecurityContext newCtx = SecurityContextHolder.createEmptyContext(); SecurityContext newCtx = this.securityContextHolderStrategy.createEmptyContext();
newCtx.setAuthentication(runAs); newCtx.setAuthentication(runAs);
SecurityContextHolder.setContext(newCtx); this.securityContextHolderStrategy.setContext(newCtx);
if (this.logger.isDebugEnabled()) { if (this.logger.isDebugEnabled()) {
this.logger.debug(LogMessage.format("Switched to RunAs authentication %s", runAs)); this.logger.debug(LogMessage.format("Switched to RunAs authentication %s", runAs));
@ -229,7 +233,7 @@ public abstract class AbstractSecurityInterceptor
} }
this.logger.trace("Did not switch RunAs authentication since RunAsManager returned null"); this.logger.trace("Did not switch RunAs authentication since RunAsManager returned null");
// no further work post-invocation // no further work post-invocation
return new InterceptorStatusToken(SecurityContextHolder.getContext(), false, attributes, object); return new InterceptorStatusToken(this.securityContextHolderStrategy.getContext(), false, attributes, object);
} }
@ -260,7 +264,7 @@ public abstract class AbstractSecurityInterceptor
*/ */
protected void finallyInvocation(InterceptorStatusToken token) { protected void finallyInvocation(InterceptorStatusToken token) {
if (token != null && token.isContextHolderRefreshRequired()) { if (token != null && token.isContextHolderRefreshRequired()) {
SecurityContextHolder.setContext(token.getSecurityContext()); this.securityContextHolderStrategy.setContext(token.getSecurityContext());
if (this.logger.isDebugEnabled()) { if (this.logger.isDebugEnabled()) {
this.logger.debug(LogMessage.of( this.logger.debug(LogMessage.of(
() -> "Reverted to original authentication " + token.getSecurityContext().getAuthentication())); () -> "Reverted to original authentication " + token.getSecurityContext().getAuthentication()));
@ -305,7 +309,7 @@ public abstract class AbstractSecurityInterceptor
* @return an authenticated <tt>Authentication</tt> object. * @return an authenticated <tt>Authentication</tt> object.
*/ */
private Authentication authenticateIfRequired() { private Authentication authenticateIfRequired() {
Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication();
if (authentication.isAuthenticated() && !this.alwaysReauthenticate) { if (authentication.isAuthenticated() && !this.alwaysReauthenticate) {
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
this.logger.trace(LogMessage.format("Did not re-authenticate %s before authorizing", authentication)); this.logger.trace(LogMessage.format("Did not re-authenticate %s before authorizing", authentication));
@ -317,9 +321,9 @@ public abstract class AbstractSecurityInterceptor
if (this.logger.isDebugEnabled()) { if (this.logger.isDebugEnabled()) {
this.logger.debug(LogMessage.format("Re-authenticated %s before authorizing", authentication)); this.logger.debug(LogMessage.format("Re-authenticated %s before authorizing", authentication));
} }
SecurityContext context = SecurityContextHolder.createEmptyContext(); SecurityContext context = this.securityContextHolderStrategy.createEmptyContext();
context.setAuthentication(authentication); context.setAuthentication(authentication);
SecurityContextHolder.setContext(context); this.securityContextHolderStrategy.setContext(context);
return authentication; return authentication;
} }
@ -378,6 +382,17 @@ public abstract class AbstractSecurityInterceptor
public abstract SecurityMetadataSource obtainSecurityMetadataSource(); public abstract SecurityMetadataSource obtainSecurityMetadataSource();
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*
* @since 5.8
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
public void setAccessDecisionManager(AccessDecisionManager accessDecisionManager) { public void setAccessDecisionManager(AccessDecisionManager accessDecisionManager) {
this.accessDecisionManager = accessDecisionManager; this.accessDecisionManager = accessDecisionManager;
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2021 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -15,6 +15,7 @@
<property name="excludes" value="io.spring.javaformat.checkstyle.check.SpringHeaderCheck" /> <property name="excludes" value="io.spring.javaformat.checkstyle.check.SpringHeaderCheck" />
<property name="avoidStaticImportExcludes" value="org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.*" /> <property name="avoidStaticImportExcludes" value="org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.*" />
<property name="avoidStaticImportExcludes" value="org.springframework.security.test.web.servlet.response.SecurityMockMvcResultHandlers.*" /> <property name="avoidStaticImportExcludes" value="org.springframework.security.test.web.servlet.response.SecurityMockMvcResultHandlers.*" />
<property name="avoidStaticImportExcludes" value="org.springframework.security.config.annotation.SecurityContextChangedListenerArgumentMatchers.*" />
</module> </module>
<module name="com.puppycrawl.tools.checkstyle.TreeWalker"> <module name="com.puppycrawl.tools.checkstyle.TreeWalker">
<module name="com.puppycrawl.tools.checkstyle.checks.regexp.RegexpSinglelineJavaCheck"> <module name="com.puppycrawl.tools.checkstyle.checks.regexp.RegexpSinglelineJavaCheck">

View File

@ -34,6 +34,7 @@ import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage; import org.springframework.core.log.LogMessage;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.web.firewall.DefaultRequestRejectedHandler; import org.springframework.security.web.firewall.DefaultRequestRejectedHandler;
import org.springframework.security.web.firewall.FirewalledRequest; import org.springframework.security.web.firewall.FirewalledRequest;
import org.springframework.security.web.firewall.HttpFirewall; import org.springframework.security.web.firewall.HttpFirewall;
@ -146,6 +147,9 @@ public class FilterChainProxy extends GenericFilterBean {
private static final String FILTER_APPLIED = FilterChainProxy.class.getName().concat(".APPLIED"); private static final String FILTER_APPLIED = FilterChainProxy.class.getName().concat(".APPLIED");
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
private List<SecurityFilterChain> filterChains; private List<SecurityFilterChain> filterChains;
private FilterChainValidator filterChainValidator = new NullFilterChainValidator(); private FilterChainValidator filterChainValidator = new NullFilterChainValidator();
@ -186,7 +190,7 @@ public class FilterChainProxy extends GenericFilterBean {
this.requestRejectedHandler.handle((HttpServletRequest) request, (HttpServletResponse) response, ex); this.requestRejectedHandler.handle((HttpServletRequest) request, (HttpServletResponse) response, ex);
} }
finally { finally {
SecurityContextHolder.clearContext(); this.securityContextHolderStrategy.clearContext();
request.removeAttribute(FILTER_APPLIED); request.removeAttribute(FILTER_APPLIED);
} }
} }
@ -247,6 +251,17 @@ public class FilterChainProxy extends GenericFilterBean {
return Collections.unmodifiableList(this.filterChains); return Collections.unmodifiableList(this.filterChains);
} }
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*
* @since 5.8
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
/** /**
* Used (internally) to specify a validation strategy for the filters in each * Used (internally) to specify a validation strategy for the filters in each
* configured chain. * configured chain.

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2004-2021 the original author or authors. * Copyright 2004-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -38,6 +38,7 @@ import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.SpringSecurityMessageSource; import org.springframework.security.core.SpringSecurityMessageSource;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.web.AuthenticationEntryPoint; import org.springframework.security.web.AuthenticationEntryPoint;
import org.springframework.security.web.savedrequest.HttpSessionRequestCache; import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
import org.springframework.security.web.savedrequest.RequestCache; import org.springframework.security.web.savedrequest.RequestCache;
@ -82,6 +83,9 @@ import org.springframework.web.filter.GenericFilterBean;
*/ */
public class ExceptionTranslationFilter extends GenericFilterBean implements MessageSourceAware { public class ExceptionTranslationFilter extends GenericFilterBean implements MessageSourceAware {
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
private AccessDeniedHandler accessDeniedHandler = new AccessDeniedHandlerImpl(); private AccessDeniedHandler accessDeniedHandler = new AccessDeniedHandlerImpl();
private AuthenticationEntryPoint authenticationEntryPoint; private AuthenticationEntryPoint authenticationEntryPoint;
@ -183,7 +187,7 @@ public class ExceptionTranslationFilter extends GenericFilterBean implements Mes
private void handleAccessDeniedException(HttpServletRequest request, HttpServletResponse response, private void handleAccessDeniedException(HttpServletRequest request, HttpServletResponse response,
FilterChain chain, AccessDeniedException exception) throws ServletException, IOException { FilterChain chain, AccessDeniedException exception) throws ServletException, IOException {
Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication();
boolean isAnonymous = this.authenticationTrustResolver.isAnonymous(authentication); boolean isAnonymous = this.authenticationTrustResolver.isAnonymous(authentication);
if (isAnonymous || this.authenticationTrustResolver.isRememberMe(authentication)) { if (isAnonymous || this.authenticationTrustResolver.isRememberMe(authentication)) {
if (logger.isTraceEnabled()) { if (logger.isTraceEnabled()) {
@ -209,8 +213,8 @@ public class ExceptionTranslationFilter extends GenericFilterBean implements Mes
AuthenticationException reason) throws ServletException, IOException { AuthenticationException reason) throws ServletException, IOException {
// SEC-112: Clear the SecurityContextHolder's Authentication, as the // SEC-112: Clear the SecurityContextHolder's Authentication, as the
// existing Authentication is no longer considered valid // existing Authentication is no longer considered valid
SecurityContext context = SecurityContextHolder.createEmptyContext(); SecurityContext context = this.securityContextHolderStrategy.createEmptyContext();
SecurityContextHolder.setContext(context); this.securityContextHolderStrategy.setContext(context);
this.requestCache.saveRequest(request, response); this.requestCache.saveRequest(request, response);
this.authenticationEntryPoint.commence(request, response, reason); this.authenticationEntryPoint.commence(request, response, reason);
} }
@ -239,6 +243,17 @@ public class ExceptionTranslationFilter extends GenericFilterBean implements Mes
this.messages = new MessageSourceAccessor(messageSource); this.messages = new MessageSourceAccessor(messageSource);
} }
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*
* @since 5.8
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
/** /**
* Default implementation of <code>ThrowableAnalyzer</code> which is capable of also * Default implementation of <code>ThrowableAnalyzer</code> which is capable of also
* unwrapping <code>ServletException</code>s. * unwrapping <code>ServletException</code>s.

View File

@ -34,6 +34,7 @@ import org.springframework.security.authorization.event.AuthorizationDeniedEvent
import org.springframework.security.authorization.event.AuthorizationGrantedEvent; import org.springframework.security.authorization.event.AuthorizationGrantedEvent;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.web.filter.OncePerRequestFilter; import org.springframework.web.filter.OncePerRequestFilter;
@ -46,6 +47,9 @@ import org.springframework.web.filter.OncePerRequestFilter;
*/ */
public class AuthorizationFilter extends OncePerRequestFilter { public class AuthorizationFilter extends OncePerRequestFilter {
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
private final AuthorizationManager<HttpServletRequest> authorizationManager; private final AuthorizationManager<HttpServletRequest> authorizationManager;
private AuthorizationEventPublisher eventPublisher = AuthorizationFilter::noPublish; private AuthorizationEventPublisher eventPublisher = AuthorizationFilter::noPublish;
@ -73,8 +77,19 @@ public class AuthorizationFilter extends OncePerRequestFilter {
filterChain.doFilter(request, response); filterChain.doFilter(request, response);
} }
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*
* @since 5.8
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
private Authentication getAuthentication() { private Authentication getAuthentication() {
Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication();
if (authentication == null) { if (authentication == null) {
throw new AuthenticationCredentialsNotFoundException( throw new AuthenticationCredentialsNotFoundException(
"An Authentication object was not found in the SecurityContext"); "An Authentication object was not found in the SecurityContext");

View File

@ -40,6 +40,7 @@ import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.SpringSecurityMessageSource; import org.springframework.security.core.SpringSecurityMessageSource;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.web.authentication.session.NullAuthenticatedSessionStrategy; import org.springframework.security.web.authentication.session.NullAuthenticatedSessionStrategy;
import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy; import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy;
import org.springframework.security.web.context.NullSecurityContextRepository; import org.springframework.security.web.context.NullSecurityContextRepository;
@ -114,6 +115,9 @@ import org.springframework.web.filter.GenericFilterBean;
public abstract class AbstractAuthenticationProcessingFilter extends GenericFilterBean public abstract class AbstractAuthenticationProcessingFilter extends GenericFilterBean
implements ApplicationEventPublisherAware, MessageSourceAware { implements ApplicationEventPublisherAware, MessageSourceAware {
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
protected ApplicationEventPublisher eventPublisher; protected ApplicationEventPublisher eventPublisher;
protected AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource = new WebAuthenticationDetailsSource(); protected AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource = new WebAuthenticationDetailsSource();
@ -315,9 +319,9 @@ public abstract class AbstractAuthenticationProcessingFilter extends GenericFilt
*/ */
protected void successfulAuthentication(HttpServletRequest request, HttpServletResponse response, FilterChain chain, protected void successfulAuthentication(HttpServletRequest request, HttpServletResponse response, FilterChain chain,
Authentication authResult) throws IOException, ServletException { Authentication authResult) throws IOException, ServletException {
SecurityContext context = SecurityContextHolder.createEmptyContext(); SecurityContext context = this.securityContextHolderStrategy.createEmptyContext();
context.setAuthentication(authResult); context.setAuthentication(authResult);
SecurityContextHolder.setContext(context); this.securityContextHolderStrategy.setContext(context);
this.securityContextRepository.saveContext(context, request, response); this.securityContextRepository.saveContext(context, request, response);
if (this.logger.isDebugEnabled()) { if (this.logger.isDebugEnabled()) {
this.logger.debug(LogMessage.format("Set SecurityContextHolder to %s", authResult)); this.logger.debug(LogMessage.format("Set SecurityContextHolder to %s", authResult));
@ -342,7 +346,7 @@ public abstract class AbstractAuthenticationProcessingFilter extends GenericFilt
*/ */
protected void unsuccessfulAuthentication(HttpServletRequest request, HttpServletResponse response, protected void unsuccessfulAuthentication(HttpServletRequest request, HttpServletResponse response,
AuthenticationException failed) throws IOException, ServletException { AuthenticationException failed) throws IOException, ServletException {
SecurityContextHolder.clearContext(); this.securityContextHolderStrategy.clearContext();
this.logger.trace("Failed to process authentication request", failed); this.logger.trace("Failed to process authentication request", failed);
this.logger.trace("Cleared SecurityContextHolder"); this.logger.trace("Cleared SecurityContextHolder");
this.logger.trace("Handling authentication failure"); this.logger.trace("Handling authentication failure");
@ -452,6 +456,17 @@ public abstract class AbstractAuthenticationProcessingFilter extends GenericFilt
this.securityContextRepository = securityContextRepository; this.securityContextRepository = securityContextRepository;
} }
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*
* @since 5.8
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
protected AuthenticationSuccessHandler getSuccessHandler() { protected AuthenticationSuccessHandler getSuccessHandler() {
return this.successHandler; return this.successHandler;
} }

View File

@ -34,6 +34,7 @@ import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.web.filter.GenericFilterBean; import org.springframework.web.filter.GenericFilterBean;
@ -46,6 +47,9 @@ import org.springframework.web.filter.GenericFilterBean;
*/ */
public class AnonymousAuthenticationFilter extends GenericFilterBean implements InitializingBean { public class AnonymousAuthenticationFilter extends GenericFilterBean implements InitializingBean {
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
private AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource = new WebAuthenticationDetailsSource(); private AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource = new WebAuthenticationDetailsSource();
private String key; private String key;
@ -87,14 +91,14 @@ public class AnonymousAuthenticationFilter extends GenericFilterBean implements
@Override @Override
public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain)
throws IOException, ServletException { throws IOException, ServletException {
if (SecurityContextHolder.getContext().getAuthentication() == null) { if (this.securityContextHolderStrategy.getContext().getAuthentication() == null) {
Authentication authentication = createAuthentication((HttpServletRequest) req); Authentication authentication = createAuthentication((HttpServletRequest) req);
SecurityContext context = SecurityContextHolder.createEmptyContext(); SecurityContext context = this.securityContextHolderStrategy.createEmptyContext();
context.setAuthentication(authentication); context.setAuthentication(authentication);
SecurityContextHolder.setContext(context); this.securityContextHolderStrategy.setContext(context);
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
this.logger.trace(LogMessage.of(() -> "Set SecurityContextHolder to " this.logger.trace(LogMessage.of(() -> "Set SecurityContextHolder to "
+ SecurityContextHolder.getContext().getAuthentication())); + this.securityContextHolderStrategy.getContext().getAuthentication()));
} }
else { else {
this.logger.debug("Set SecurityContextHolder to anonymous SecurityContext"); this.logger.debug("Set SecurityContextHolder to anonymous SecurityContext");
@ -103,7 +107,7 @@ public class AnonymousAuthenticationFilter extends GenericFilterBean implements
else { else {
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
this.logger.trace(LogMessage.of(() -> "Did not set SecurityContextHolder since already authenticated " this.logger.trace(LogMessage.of(() -> "Did not set SecurityContextHolder since already authenticated "
+ SecurityContextHolder.getContext().getAuthentication())); + this.securityContextHolderStrategy.getContext().getAuthentication()));
} }
} }
chain.doFilter(req, res); chain.doFilter(req, res);
@ -122,6 +126,17 @@ public class AnonymousAuthenticationFilter extends GenericFilterBean implements
this.authenticationDetailsSource = authenticationDetailsSource; this.authenticationDetailsSource = authenticationDetailsSource;
} }
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*
* @since 5.8
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
public Object getPrincipal() { public Object getPrincipal() {
return this.principal; return this.principal;
} }

View File

@ -28,6 +28,7 @@ import javax.servlet.http.HttpServletResponse;
import org.springframework.core.log.LogMessage; import org.springframework.core.log.LogMessage;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.web.util.UrlUtils; import org.springframework.security.web.util.UrlUtils;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher;
@ -52,6 +53,9 @@ import org.springframework.web.filter.GenericFilterBean;
*/ */
public class LogoutFilter extends GenericFilterBean { public class LogoutFilter extends GenericFilterBean {
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
private RequestMatcher logoutRequestMatcher; private RequestMatcher logoutRequestMatcher;
private final LogoutHandler handler; private final LogoutHandler handler;
@ -92,7 +96,7 @@ public class LogoutFilter extends GenericFilterBean {
private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain) private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
throws IOException, ServletException { throws IOException, ServletException {
if (requiresLogout(request, response)) { if (requiresLogout(request, response)) {
Authentication auth = SecurityContextHolder.getContext().getAuthentication(); Authentication auth = this.securityContextHolderStrategy.getContext().getAuthentication();
if (this.logger.isDebugEnabled()) { if (this.logger.isDebugEnabled()) {
this.logger.debug(LogMessage.format("Logging out [%s]", auth)); this.logger.debug(LogMessage.format("Logging out [%s]", auth));
} }
@ -119,6 +123,17 @@ public class LogoutFilter extends GenericFilterBean {
return false; return false;
} }
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*
* @since 5.8
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
public void setLogoutRequestMatcher(RequestMatcher logoutRequestMatcher) { public void setLogoutRequestMatcher(RequestMatcher logoutRequestMatcher) {
Assert.notNull(logoutRequestMatcher, "logoutRequestMatcher cannot be null"); Assert.notNull(logoutRequestMatcher, "logoutRequestMatcher cannot be null");
this.logoutRequestMatcher = logoutRequestMatcher; this.logoutRequestMatcher = logoutRequestMatcher;

View File

@ -33,6 +33,7 @@ import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.web.AuthenticationEntryPoint; import org.springframework.security.web.AuthenticationEntryPoint;
import org.springframework.security.web.authentication.NullRememberMeServices; import org.springframework.security.web.authentication.NullRememberMeServices;
import org.springframework.security.web.authentication.RememberMeServices; import org.springframework.security.web.authentication.RememberMeServices;
@ -93,6 +94,9 @@ import org.springframework.web.filter.OncePerRequestFilter;
*/ */
public class BasicAuthenticationFilter extends OncePerRequestFilter { public class BasicAuthenticationFilter extends OncePerRequestFilter {
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
private AuthenticationEntryPoint authenticationEntryPoint; private AuthenticationEntryPoint authenticationEntryPoint;
private AuthenticationManager authenticationManager; private AuthenticationManager authenticationManager;
@ -170,9 +174,9 @@ public class BasicAuthenticationFilter extends OncePerRequestFilter {
this.logger.trace(LogMessage.format("Found username '%s' in Basic Authorization header", username)); this.logger.trace(LogMessage.format("Found username '%s' in Basic Authorization header", username));
if (authenticationIsRequired(username)) { if (authenticationIsRequired(username)) {
Authentication authResult = this.authenticationManager.authenticate(authRequest); Authentication authResult = this.authenticationManager.authenticate(authRequest);
SecurityContext context = SecurityContextHolder.createEmptyContext(); SecurityContext context = this.securityContextHolderStrategy.createEmptyContext();
context.setAuthentication(authResult); context.setAuthentication(authResult);
SecurityContextHolder.setContext(context); this.securityContextHolderStrategy.setContext(context);
if (this.logger.isDebugEnabled()) { if (this.logger.isDebugEnabled()) {
this.logger.debug(LogMessage.format("Set SecurityContextHolder to %s", authResult)); this.logger.debug(LogMessage.format("Set SecurityContextHolder to %s", authResult));
} }
@ -182,7 +186,7 @@ public class BasicAuthenticationFilter extends OncePerRequestFilter {
} }
} }
catch (AuthenticationException ex) { catch (AuthenticationException ex) {
SecurityContextHolder.clearContext(); this.securityContextHolderStrategy.clearContext();
this.logger.debug("Failed to process authentication request", ex); this.logger.debug("Failed to process authentication request", ex);
this.rememberMeServices.loginFail(request, response); this.rememberMeServices.loginFail(request, response);
onUnsuccessfulAuthentication(request, response, ex); onUnsuccessfulAuthentication(request, response, ex);
@ -201,7 +205,7 @@ public class BasicAuthenticationFilter extends OncePerRequestFilter {
private boolean authenticationIsRequired(String username) { private boolean authenticationIsRequired(String username) {
// Only reauthenticate if username doesn't match SecurityContextHolder and user // Only reauthenticate if username doesn't match SecurityContextHolder and user
// isn't authenticated (see SEC-53) // isn't authenticated (see SEC-53)
Authentication existingAuth = SecurityContextHolder.getContext().getAuthentication(); Authentication existingAuth = this.securityContextHolderStrategy.getContext().getAuthentication();
if (existingAuth == null || !existingAuth.isAuthenticated()) { if (existingAuth == null || !existingAuth.isAuthenticated()) {
return true; return true;
} }
@ -242,6 +246,17 @@ public class BasicAuthenticationFilter extends OncePerRequestFilter {
return this.ignoreFailure; return this.ignoreFailure;
} }
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*
* @since 5.8
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
public void setAuthenticationDetailsSource( public void setAuthenticationDetailsSource(
AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource) { AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource) {
this.authenticationConverter.setAuthenticationDetailsSource(authenticationDetailsSource); this.authenticationConverter.setAuthenticationDetailsSource(authenticationDetailsSource);

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2021 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -90,11 +90,14 @@ public class HttpSessionSecurityContextRepository implements SecurityContextRepo
protected final Log logger = LogFactory.getLog(this.getClass()); protected final Log logger = LogFactory.getLog(this.getClass());
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
/** /**
* SecurityContext instance used to check for equality with default (unauthenticated) * SecurityContext instance used to check for equality with default (unauthenticated)
* content * content
*/ */
private final Object contextObject = SecurityContextHolder.createEmptyContext(); private Object contextObject = this.securityContextHolderStrategy.createEmptyContext();
private boolean allowSessionCreation = true; private boolean allowSessionCreation = true;
@ -126,6 +129,7 @@ public class HttpSessionSecurityContextRepository implements SecurityContextRepo
if (response != null) { if (response != null) {
SaveToSessionResponseWrapper wrappedResponse = new SaveToSessionResponseWrapper(response, request, SaveToSessionResponseWrapper wrappedResponse = new SaveToSessionResponseWrapper(response, request,
httpSession != null, context); httpSession != null, context);
wrappedResponse.setSecurityContextHolderStrategy(this.securityContextHolderStrategy);
requestResponseHolder.setResponse(wrappedResponse); requestResponseHolder.setResponse(wrappedResponse);
requestResponseHolder.setRequest(new SaveToSessionRequestWrapper(request, wrappedResponse)); requestResponseHolder.setRequest(new SaveToSessionRequestWrapper(request, wrappedResponse));
} }
@ -201,7 +205,7 @@ public class HttpSessionSecurityContextRepository implements SecurityContextRepo
* @return a new SecurityContext instance. Never null. * @return a new SecurityContext instance. Never null.
*/ */
protected SecurityContext generateNewContext() { protected SecurityContext generateNewContext() {
return SecurityContextHolder.createEmptyContext(); return this.securityContextHolderStrategy.createEmptyContext();
} }
/** /**
@ -237,6 +241,17 @@ public class HttpSessionSecurityContextRepository implements SecurityContextRepo
this.springSecurityContextKey = springSecurityContextKey; this.springSecurityContextKey = springSecurityContextKey;
} }
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*
* @since 5.8
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy strategy) {
this.securityContextHolderStrategy = strategy;
this.contextObject = this.securityContextHolderStrategy.createEmptyContext();
}
private boolean isTransient(Object object) { private boolean isTransient(Object object) {
if (object == null) { if (object == null) {
return false; return false;

View File

@ -21,6 +21,7 @@ import javax.servlet.http.HttpServletResponse;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
/** /**
* @author Luke Taylor * @author Luke Taylor
@ -28,6 +29,9 @@ import org.springframework.security.core.context.SecurityContextHolder;
*/ */
public final class NullSecurityContextRepository implements SecurityContextRepository { public final class NullSecurityContextRepository implements SecurityContextRepository {
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
@Override @Override
public boolean containsContext(HttpServletRequest request) { public boolean containsContext(HttpServletRequest request) {
return false; return false;
@ -35,11 +39,21 @@ public final class NullSecurityContextRepository implements SecurityContextRepos
@Override @Override
public SecurityContext loadContext(HttpRequestResponseHolder requestResponseHolder) { public SecurityContext loadContext(HttpRequestResponseHolder requestResponseHolder) {
return SecurityContextHolder.createEmptyContext(); return this.securityContextHolderStrategy.createEmptyContext();
} }
@Override @Override
public void saveContext(SecurityContext context, HttpServletRequest request, HttpServletResponse response) { public void saveContext(SecurityContext context, HttpServletRequest request, HttpServletResponse response) {
} }
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*
* @since 5.8
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy strategy) {
this.securityContextHolderStrategy = strategy;
}
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2016 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -21,7 +21,9 @@ import javax.servlet.http.HttpServletResponse;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.web.util.OnCommittedResponseWrapper; import org.springframework.security.web.util.OnCommittedResponseWrapper;
import org.springframework.util.Assert;
/** /**
* Base class for response wrappers which encapsulate the logic for storing a security * Base class for response wrappers which encapsulate the logic for storing a security
@ -46,6 +48,9 @@ import org.springframework.security.web.util.OnCommittedResponseWrapper;
@Deprecated @Deprecated
public abstract class SaveContextOnUpdateOrErrorResponseWrapper extends OnCommittedResponseWrapper { public abstract class SaveContextOnUpdateOrErrorResponseWrapper extends OnCommittedResponseWrapper {
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
private boolean contextSaved = false; private boolean contextSaved = false;
// See SEC-1052 // See SEC-1052
@ -62,6 +67,17 @@ public abstract class SaveContextOnUpdateOrErrorResponseWrapper extends OnCommit
this.disableUrlRewriting = disableUrlRewriting; this.disableUrlRewriting = disableUrlRewriting;
} }
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*
* @since 5.8
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
/** /**
* Invoke this method to disable automatic saving of the {@link SecurityContext} when * Invoke this method to disable automatic saving of the {@link SecurityContext} when
* the {@link HttpServletResponse} is committed. This can be useful in the event that * the {@link HttpServletResponse} is committed. This can be useful in the event that
@ -85,7 +101,7 @@ public abstract class SaveContextOnUpdateOrErrorResponseWrapper extends OnCommit
*/ */
@Override @Override
protected void onResponseCommitted() { protected void onResponseCommitted() {
saveContext(SecurityContextHolder.getContext()); saveContext(this.securityContextHolderStrategy.getContext());
this.contextSaved = true; this.contextSaved = true;
} }

View File

@ -25,6 +25,7 @@ import javax.servlet.http.HttpServletResponse;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.web.filter.OncePerRequestFilter; import org.springframework.web.filter.OncePerRequestFilter;
@ -44,6 +45,9 @@ public class SecurityContextHolderFilter extends OncePerRequestFilter {
private final SecurityContextRepository securityContextRepository; private final SecurityContextRepository securityContextRepository;
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
private boolean shouldNotFilterErrorDispatch; private boolean shouldNotFilterErrorDispatch;
/** /**
@ -60,11 +64,11 @@ public class SecurityContextHolderFilter extends OncePerRequestFilter {
throws ServletException, IOException { throws ServletException, IOException {
SecurityContext securityContext = this.securityContextRepository.loadContext(request).get(); SecurityContext securityContext = this.securityContextRepository.loadContext(request).get();
try { try {
SecurityContextHolder.setContext(securityContext); this.securityContextHolderStrategy.setContext(securityContext);
filterChain.doFilter(request, response); filterChain.doFilter(request, response);
} }
finally { finally {
SecurityContextHolder.clearContext(); this.securityContextHolderStrategy.clearContext();
} }
} }
@ -73,6 +77,17 @@ public class SecurityContextHolderFilter extends OncePerRequestFilter {
return this.shouldNotFilterErrorDispatch; return this.shouldNotFilterErrorDispatch;
} }
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*
* @since 5.8
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
/** /**
* Disables {@link SecurityContextHolderFilter} for error dispatch. * Disables {@link SecurityContextHolderFilter} for error dispatch.
* @param shouldNotFilterErrorDispatch if the Filter should be disabled for error * @param shouldNotFilterErrorDispatch if the Filter should be disabled for error

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2016 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -29,6 +29,8 @@ import javax.servlet.http.HttpSession;
import org.springframework.core.log.LogMessage; import org.springframework.core.log.LogMessage;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.util.Assert;
import org.springframework.web.filter.GenericFilterBean; import org.springframework.web.filter.GenericFilterBean;
/** /**
@ -66,6 +68,9 @@ public class SecurityContextPersistenceFilter extends GenericFilterBean {
private SecurityContextRepository repo; private SecurityContextRepository repo;
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
private boolean forceEagerSessionCreation = false; private boolean forceEagerSessionCreation = false;
public SecurityContextPersistenceFilter() { public SecurityContextPersistenceFilter() {
@ -99,7 +104,7 @@ public class SecurityContextPersistenceFilter extends GenericFilterBean {
HttpRequestResponseHolder holder = new HttpRequestResponseHolder(request, response); HttpRequestResponseHolder holder = new HttpRequestResponseHolder(request, response);
SecurityContext contextBeforeChainExecution = this.repo.loadContext(holder); SecurityContext contextBeforeChainExecution = this.repo.loadContext(holder);
try { try {
SecurityContextHolder.setContext(contextBeforeChainExecution); this.securityContextHolderStrategy.setContext(contextBeforeChainExecution);
if (contextBeforeChainExecution.getAuthentication() == null) { if (contextBeforeChainExecution.getAuthentication() == null) {
logger.debug("Set SecurityContextHolder to empty SecurityContext"); logger.debug("Set SecurityContextHolder to empty SecurityContext");
} }
@ -112,9 +117,9 @@ public class SecurityContextPersistenceFilter extends GenericFilterBean {
chain.doFilter(holder.getRequest(), holder.getResponse()); chain.doFilter(holder.getRequest(), holder.getResponse());
} }
finally { finally {
SecurityContext contextAfterChainExecution = SecurityContextHolder.getContext(); SecurityContext contextAfterChainExecution = this.securityContextHolderStrategy.getContext();
// Crucial removal of SecurityContextHolder contents before anything else. // Crucial removal of SecurityContextHolder contents before anything else.
SecurityContextHolder.clearContext(); this.securityContextHolderStrategy.clearContext();
this.repo.saveContext(contextAfterChainExecution, holder.getRequest(), holder.getResponse()); this.repo.saveContext(contextAfterChainExecution, holder.getRequest(), holder.getResponse());
request.removeAttribute(FILTER_APPLIED); request.removeAttribute(FILTER_APPLIED);
this.logger.debug("Cleared SecurityContextHolder to complete request"); this.logger.debug("Cleared SecurityContextHolder to complete request");
@ -125,4 +130,15 @@ public class SecurityContextPersistenceFilter extends GenericFilterBean {
this.forceEagerSessionCreation = forceEagerSessionCreation; this.forceEagerSessionCreation = forceEagerSessionCreation;
} }
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*
* @since 5.8
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
} }

View File

@ -28,7 +28,9 @@ import org.springframework.expression.spel.support.StandardEvaluationContext;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.annotation.AuthenticationPrincipal; import org.springframework.security.core.annotation.AuthenticationPrincipal;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.stereotype.Controller; import org.springframework.stereotype.Controller;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils; import org.springframework.util.ClassUtils;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import org.springframework.web.bind.support.WebDataBinderFactory; import org.springframework.web.bind.support.WebDataBinderFactory;
@ -88,6 +90,9 @@ import org.springframework.web.method.support.ModelAndViewContainer;
*/ */
public final class AuthenticationPrincipalArgumentResolver implements HandlerMethodArgumentResolver { public final class AuthenticationPrincipalArgumentResolver implements HandlerMethodArgumentResolver {
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
private ExpressionParser parser = new SpelExpressionParser(); private ExpressionParser parser = new SpelExpressionParser();
private BeanResolver beanResolver; private BeanResolver beanResolver;
@ -100,7 +105,7 @@ public final class AuthenticationPrincipalArgumentResolver implements HandlerMet
@Override @Override
public Object resolveArgument(MethodParameter parameter, ModelAndViewContainer mavContainer, public Object resolveArgument(MethodParameter parameter, ModelAndViewContainer mavContainer,
NativeWebRequest webRequest, WebDataBinderFactory binderFactory) { NativeWebRequest webRequest, WebDataBinderFactory binderFactory) {
Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication();
if (authentication == null) { if (authentication == null) {
return null; return null;
} }
@ -132,6 +137,17 @@ public final class AuthenticationPrincipalArgumentResolver implements HandlerMet
this.beanResolver = beanResolver; this.beanResolver = beanResolver;
} }
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*
* @since 5.8
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
/** /**
* Obtains the specified {@link Annotation} on the specified {@link MethodParameter}. * Obtains the specified {@link Annotation} on the specified {@link MethodParameter}.
* @param annotationClass the class of the {@link Annotation} to find on the * @param annotationClass the class of the {@link Annotation} to find on the

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2016 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -30,6 +30,7 @@ import org.springframework.security.authentication.AuthenticationTrustResolver;
import org.springframework.security.authentication.AuthenticationTrustResolverImpl; import org.springframework.security.authentication.AuthenticationTrustResolverImpl;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.security.web.authentication.SimpleUrlAuthenticationFailureHandler; import org.springframework.security.web.authentication.SimpleUrlAuthenticationFailureHandler;
import org.springframework.security.web.authentication.session.SessionAuthenticationException; import org.springframework.security.web.authentication.session.SessionAuthenticationException;
@ -53,6 +54,9 @@ public class SessionManagementFilter extends GenericFilterBean {
static final String FILTER_APPLIED = "__spring_security_session_mgmt_filter_applied"; static final String FILTER_APPLIED = "__spring_security_session_mgmt_filter_applied";
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
private final SecurityContextRepository securityContextRepository; private final SecurityContextRepository securityContextRepository;
private SessionAuthenticationStrategy sessionAuthenticationStrategy; private SessionAuthenticationStrategy sessionAuthenticationStrategy;
@ -89,7 +93,7 @@ public class SessionManagementFilter extends GenericFilterBean {
} }
request.setAttribute(FILTER_APPLIED, Boolean.TRUE); request.setAttribute(FILTER_APPLIED, Boolean.TRUE);
if (!this.securityContextRepository.containsContext(request)) { if (!this.securityContextRepository.containsContext(request)) {
Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication();
if (authentication != null && !this.trustResolver.isAnonymous(authentication)) { if (authentication != null && !this.trustResolver.isAnonymous(authentication)) {
// The user has been authenticated during the current request, so call the // The user has been authenticated during the current request, so call the
// session strategy // session strategy
@ -99,14 +103,15 @@ public class SessionManagementFilter extends GenericFilterBean {
catch (SessionAuthenticationException ex) { catch (SessionAuthenticationException ex) {
// The session strategy can reject the authentication // The session strategy can reject the authentication
this.logger.debug("SessionAuthenticationStrategy rejected the authentication object", ex); this.logger.debug("SessionAuthenticationStrategy rejected the authentication object", ex);
SecurityContextHolder.clearContext(); this.securityContextHolderStrategy.clearContext();
this.failureHandler.onAuthenticationFailure(request, response, ex); this.failureHandler.onAuthenticationFailure(request, response, ex);
return; return;
} }
// Eagerly save the security context to make it available for any possible // Eagerly save the security context to make it available for any possible
// re-entrant requests which may occur before the current request // re-entrant requests which may occur before the current request
// completes. SEC-1396. // completes. SEC-1396.
this.securityContextRepository.saveContext(SecurityContextHolder.getContext(), request, response); this.securityContextRepository.saveContext(this.securityContextHolderStrategy.getContext(), request,
response);
} }
else { else {
// No security context or authentication present. Check for a session // No security context or authentication present. Check for a session
@ -160,4 +165,15 @@ public class SessionManagementFilter extends GenericFilterBean {
this.trustResolver = trustResolver; this.trustResolver = trustResolver;
} }
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*
* @since 5.8
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2016 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -36,6 +36,7 @@ import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.web.firewall.FirewalledRequest; import org.springframework.security.web.firewall.FirewalledRequest;
import org.springframework.security.web.firewall.HttpFirewall; import org.springframework.security.web.firewall.HttpFirewall;
import org.springframework.security.web.firewall.RequestRejectedException; import org.springframework.security.web.firewall.RequestRejectedException;
@ -198,6 +199,15 @@ public class FilterChainProxyTests {
assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
} }
@Test
public void doFilterWhenCustomSecurityContextHolderStrategyClearsSecurityContext() throws Exception {
SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class);
this.fcp.setSecurityContextHolderStrategy(strategy);
given(this.matcher.matches(any(HttpServletRequest.class))).willReturn(true);
this.fcp.doFilter(this.request, this.response, this.chain);
verify(strategy).clearContext();
}
@Test @Test
public void doFilterClearsSecurityContextHolderWithException() throws Exception { public void doFilterClearsSecurityContextHolderWithException() throws Exception {
given(this.matcher.matches(any(HttpServletRequest.class))).willReturn(true); given(this.matcher.matches(any(HttpServletRequest.class))).willReturn(true);

View File

@ -38,6 +38,7 @@ import org.springframework.security.authorization.AuthorizationManager;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.core.context.SecurityContextImpl; import org.springframework.security.core.context.SecurityContextImpl;
import org.springframework.web.util.WebUtils; import org.springframework.web.util.WebUtils;
@ -72,9 +73,9 @@ public class AuthorizationFilterTests {
AuthorizationFilter filter = new AuthorizationFilter(mockAuthorizationManager); AuthorizationFilter filter = new AuthorizationFilter(mockAuthorizationManager);
TestingAuthenticationToken authenticationToken = new TestingAuthenticationToken("user", "password"); TestingAuthenticationToken authenticationToken = new TestingAuthenticationToken("user", "password");
SecurityContext securityContext = new SecurityContextImpl(); SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class);
securityContext.setAuthentication(authenticationToken); given(strategy.getContext()).willReturn(new SecurityContextImpl(authenticationToken));
SecurityContextHolder.setContext(securityContext); filter.setSecurityContextHolderStrategy(strategy);
MockHttpServletRequest mockRequest = new MockHttpServletRequest(null, "/path"); MockHttpServletRequest mockRequest = new MockHttpServletRequest(null, "/path");
MockHttpServletResponse mockResponse = new MockHttpServletResponse(); MockHttpServletResponse mockResponse = new MockHttpServletResponse();
@ -88,6 +89,7 @@ public class AuthorizationFilterTests {
assertThat(authentication.get()).isEqualTo(authenticationToken); assertThat(authentication.get()).isEqualTo(authenticationToken);
verify(mockFilterChain).doFilter(mockRequest, mockResponse); verify(mockFilterChain).doFilter(mockRequest, mockResponse);
verify(strategy).getContext();
} }
@Test @Test

View File

@ -35,11 +35,17 @@ import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.core.context.SecurityContextImpl;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.assertj.core.api.Assertions.fail; import static org.assertj.core.api.Assertions.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
/** /**
* Tests {@link AnonymousAuthenticationFilter}. * Tests {@link AnonymousAuthenticationFilter}.
@ -74,16 +80,19 @@ public class AnonymousAuthenticationFilterTests {
public void testOperationWhenAuthenticationExistsInContextHolder() throws Exception { public void testOperationWhenAuthenticationExistsInContextHolder() throws Exception {
// Put an Authentication object into the SecurityContextHolder // Put an Authentication object into the SecurityContextHolder
Authentication originalAuth = new TestingAuthenticationToken("user", "password", "ROLE_A"); Authentication originalAuth = new TestingAuthenticationToken("user", "password", "ROLE_A");
SecurityContextHolder.getContext().setAuthentication(originalAuth); SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class);
given(strategy.getContext()).willReturn(new SecurityContextImpl(originalAuth));
AnonymousAuthenticationFilter filter = new AnonymousAuthenticationFilter("qwerty", "anonymousUsername", AnonymousAuthenticationFilter filter = new AnonymousAuthenticationFilter("qwerty", "anonymousUsername",
AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"));
filter.setSecurityContextHolderStrategy(strategy);
// Test // Test
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = new MockHttpServletRequest();
request.setRequestURI("x"); request.setRequestURI("x");
executeFilterInContainerSimulator(mock(FilterConfig.class), filter, request, new MockHttpServletResponse(), executeFilterInContainerSimulator(mock(FilterConfig.class), filter, request, new MockHttpServletResponse(),
new MockFilterChain(true)); new MockFilterChain(true));
// Ensure filter didn't change our original object // Ensure filter didn't change our original object
assertThat(SecurityContextHolder.getContext().getAuthentication()).isEqualTo(originalAuth); verify(strategy).getContext();
verify(strategy, never()).setContext(any());
} }
@Test @Test

View File

@ -17,20 +17,28 @@
package org.springframework.security.web.authentication; package org.springframework.security.web.authentication;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
/** /**
* Tests {@link UsernamePasswordAuthenticationFilter}. * Tests {@link UsernamePasswordAuthenticationFilter}.
@ -118,6 +126,22 @@ public class UsernamePasswordAuthenticationFilterTests {
.isThrownBy(() -> filter.attemptAuthentication(request, new MockHttpServletResponse())); .isThrownBy(() -> filter.attemptAuthentication(request, new MockHttpServletResponse()));
} }
@Test
public void testSecurityContextHolderStrategyUsed() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest("POST", "/login");
request.setServletPath("/login");
request.addParameter(UsernamePasswordAuthenticationFilter.SPRING_SECURITY_FORM_USERNAME_KEY, "rod");
request.addParameter(UsernamePasswordAuthenticationFilter.SPRING_SECURITY_FORM_PASSWORD_KEY, "koala");
UsernamePasswordAuthenticationFilter filter = new UsernamePasswordAuthenticationFilter();
filter.setAuthenticationManager(createAuthenticationManager());
SecurityContextHolderStrategy strategy = spy(SecurityContextHolder.getContextHolderStrategy());
filter.setSecurityContextHolderStrategy(strategy);
filter.doFilter(request, new MockHttpServletResponse(), new MockFilterChain());
ArgumentCaptor<SecurityContext> captor = ArgumentCaptor.forClass(SecurityContext.class);
verify(strategy).setContext(captor.capture());
assertThat(captor.getValue().getAuthentication()).isInstanceOf(UsernamePasswordAuthenticationToken.class);
}
/** /**
* SEC-571 * SEC-571
*/ */

View File

@ -28,6 +28,7 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.mock.web.MockHttpSession; import org.springframework.mock.web.MockHttpSession;
@ -38,6 +39,7 @@ import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.test.web.CodecTestUtils; import org.springframework.security.test.web.CodecTestUtils;
import org.springframework.security.web.authentication.WebAuthenticationDetails; import org.springframework.security.web.authentication.WebAuthenticationDetails;
import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.context.SecurityContextRepository;
@ -51,6 +53,7 @@ import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
/** /**
@ -146,6 +149,19 @@ public class BasicAuthenticationFilterTests {
assertThat(SecurityContextHolder.getContext().getAuthentication().getName()).isEqualTo("rod"); assertThat(SecurityContextHolder.getContext().getAuthentication().getName()).isEqualTo("rod");
} }
@Test
public void testSecurityContextHolderStrategyUsed() throws Exception {
String token = "rod:koala";
MockHttpServletRequest request = new MockHttpServletRequest();
request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token.getBytes()));
SecurityContextHolderStrategy strategy = spy(SecurityContextHolder.getContextHolderStrategy());
this.filter.setSecurityContextHolderStrategy(strategy);
this.filter.doFilter(request, new MockHttpServletResponse(), new MockFilterChain());
ArgumentCaptor<SecurityContext> captor = ArgumentCaptor.forClass(SecurityContext.class);
verify(strategy).setContext(captor.capture());
assertThat(captor.getValue().getAuthentication()).isInstanceOf(UsernamePasswordAuthenticationToken.class);
}
// gh-5586 // gh-5586
@Test @Test
public void doFilterWhenSchemeLowercaseThenCaseInsensitveMatchWorks() throws Exception { public void doFilterWhenSchemeLowercaseThenCaseInsensitveMatchWorks() throws Exception {

View File

@ -33,10 +33,12 @@ import org.springframework.security.authentication.TestAuthentication;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.core.context.SecurityContextImpl; import org.springframework.security.core.context.SecurityContextImpl;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.verify;
@ExtendWith(MockitoExtension.class) @ExtendWith(MockitoExtension.class)
class SecurityContextHolderFilterTests { class SecurityContextHolderFilterTests {
@ -44,6 +46,9 @@ class SecurityContextHolderFilterTests {
@Mock @Mock
private SecurityContextRepository repository; private SecurityContextRepository repository;
@Mock
private SecurityContextHolderStrategy strategy;
@Mock @Mock
private HttpServletRequest request; private HttpServletRequest request;
@ -78,6 +83,21 @@ class SecurityContextHolderFilterTests {
assertThat(SecurityContextHolder.getContext()).isEqualTo(SecurityContextHolder.createEmptyContext()); assertThat(SecurityContextHolder.getContext()).isEqualTo(SecurityContextHolder.createEmptyContext());
} }
@Test
void doFilterThenSetsAndClearsSecurityContextHolderStrategy() throws Exception {
Authentication authentication = TestAuthentication.authenticatedUser();
SecurityContext expectedContext = new SecurityContextImpl(authentication);
given(this.repository.loadContext(this.requestArg.capture())).willReturn(() -> expectedContext);
FilterChain filterChain = (request, response) -> {
};
this.filter.setSecurityContextHolderStrategy(this.strategy);
this.filter.doFilter(this.request, this.response, filterChain);
verify(this.strategy).setContext(expectedContext);
verify(this.strategy).clearContext();
}
@Test @Test
void shouldNotFilterErrorDispatchWhenDefault() { void shouldNotFilterErrorDispatchWhenDefault() {
assertThat(this.filter.shouldNotFilterErrorDispatch()).isFalse(); assertThat(this.filter.shouldNotFilterErrorDispatch()).isFalse();