diff --git a/messaging/src/main/java/org/springframework/security/messaging/access/intercept/AuthorizationChannelInterceptor.java b/messaging/src/main/java/org/springframework/security/messaging/access/intercept/AuthorizationChannelInterceptor.java index 5ec879f9cc..c61e3b660e 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/access/intercept/AuthorizationChannelInterceptor.java +++ b/messaging/src/main/java/org/springframework/security/messaging/access/intercept/AuthorizationChannelInterceptor.java @@ -32,6 +32,7 @@ import org.springframework.security.authorization.AuthorizationEventPublisher; import org.springframework.security.authorization.AuthorizationManager; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.util.Assert; /** @@ -42,14 +43,8 @@ import org.springframework.util.Assert; */ public final class AuthorizationChannelInterceptor implements ChannelInterceptor { - static final Supplier AUTHENTICATION_SUPPLIER = () -> { - Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); - if (authentication == null) { - throw new AuthenticationCredentialsNotFoundException( - "An Authentication object was not found in the SecurityContext"); - } - return authentication; - }; + private Supplier authentication = getAuthentication( + SecurityContextHolder.getContextHolderStrategy()); private final Log logger = LogFactory.getLog(this.getClass()); @@ -71,8 +66,8 @@ public final class AuthorizationChannelInterceptor implements ChannelInterceptor @Override public Message preSend(Message message, MessageChannel channel) { this.logger.debug(LogMessage.of(() -> "Authorizing message send")); - AuthorizationDecision decision = this.preSendAuthorizationManager.check(AUTHENTICATION_SUPPLIER, message); - this.eventPublisher.publishAuthorizationEvent(AUTHENTICATION_SUPPLIER, message, decision); + AuthorizationDecision decision = this.preSendAuthorizationManager.check(this.authentication, message); + this.eventPublisher.publishAuthorizationEvent(this.authentication, message, decision); if (decision == null || !decision.isGranted()) { // default deny this.logger.debug(LogMessage.of(() -> "Failed to authorize message with authorization manager " + this.preSendAuthorizationManager + " and decision " + decision)); @@ -82,6 +77,14 @@ public final class AuthorizationChannelInterceptor implements ChannelInterceptor return message; } + /** + * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use + * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}. + */ + public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) { + this.authentication = getAuthentication(securityContextHolderStrategy); + } + /** * Use this {@link AuthorizationEventPublisher} to publish the * {@link AuthorizationManager} result. @@ -92,6 +95,17 @@ public final class AuthorizationChannelInterceptor implements ChannelInterceptor this.eventPublisher = eventPublisher; } + private Supplier getAuthentication(SecurityContextHolderStrategy strategy) { + return () -> { + Authentication authentication = strategy.getContext().getAuthentication(); + if (authentication == null) { + throw new AuthenticationCredentialsNotFoundException( + "An Authentication object was not found in the SecurityContext"); + } + return authentication; + }; + } + private static class NoopAuthorizationEventPublisher implements AuthorizationEventPublisher { @Override diff --git a/messaging/src/main/java/org/springframework/security/messaging/context/AuthenticationPrincipalArgumentResolver.java b/messaging/src/main/java/org/springframework/security/messaging/context/AuthenticationPrincipalArgumentResolver.java index aaf99c16ee..09cbddfefc 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/context/AuthenticationPrincipalArgumentResolver.java +++ b/messaging/src/main/java/org/springframework/security/messaging/context/AuthenticationPrincipalArgumentResolver.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,7 +29,9 @@ import org.springframework.messaging.handler.invocation.HandlerMethodArgumentRes import org.springframework.security.core.Authentication; import org.springframework.security.core.annotation.AuthenticationPrincipal; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.stereotype.Controller; +import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import org.springframework.util.StringUtils; @@ -85,6 +87,9 @@ import org.springframework.util.StringUtils; */ public final class AuthenticationPrincipalArgumentResolver implements HandlerMethodArgumentResolver { + private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder + .getContextHolderStrategy(); + private ExpressionParser parser = new SpelExpressionParser(); @Override @@ -94,7 +99,7 @@ public final class AuthenticationPrincipalArgumentResolver implements HandlerMet @Override public Object resolveArgument(MethodParameter parameter, Message message) { - Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); + Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication(); if (authentication == null) { return null; } @@ -117,6 +122,17 @@ public final class AuthenticationPrincipalArgumentResolver implements HandlerMet return principal; } + /** + * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use + * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}. + * + * @since 5.8 + */ + public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) { + Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null"); + this.securityContextHolderStrategy = securityContextHolderStrategy; + } + /** * Obtains the specified {@link Annotation} on the specified {@link MethodParameter}. * @param annotationClass the class of the {@link Annotation} to find on the diff --git a/messaging/src/main/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptor.java b/messaging/src/main/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptor.java index b7f959c499..153d25578d 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptor.java +++ b/messaging/src/main/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptor.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,6 +29,7 @@ import org.springframework.security.core.Authentication; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.util.Assert; /** @@ -42,10 +43,13 @@ import org.springframework.util.Assert; */ public final class SecurityContextChannelInterceptor implements ExecutorChannelInterceptor, ChannelInterceptor { - private static final SecurityContext EMPTY_CONTEXT = SecurityContextHolder.createEmptyContext(); - private static final ThreadLocal> originalContext = new ThreadLocal<>(); + private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder + .getContextHolderStrategy(); + + private SecurityContext empty = this.securityContextHolderStrategy.createEmptyContext(); + private final String authenticationHeaderName; private Authentication anonymous = new AnonymousAuthenticationToken("key", "anonymous", @@ -107,8 +111,13 @@ public final class SecurityContextChannelInterceptor implements ExecutorChannelI cleanup(); } + public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy strategy) { + this.securityContextHolderStrategy = strategy; + this.empty = this.securityContextHolderStrategy.createEmptyContext(); + } + private void setup(Message message) { - SecurityContext currentContext = SecurityContextHolder.getContext(); + SecurityContext currentContext = this.securityContextHolderStrategy.getContext(); Stack contextStack = originalContext.get(); if (contextStack == null) { contextStack = new Stack<>(); @@ -117,9 +126,9 @@ public final class SecurityContextChannelInterceptor implements ExecutorChannelI contextStack.push(currentContext); Object user = message.getHeaders().get(this.authenticationHeaderName); Authentication authentication = getAuthentication(user); - SecurityContext context = SecurityContextHolder.createEmptyContext(); + SecurityContext context = this.securityContextHolderStrategy.createEmptyContext(); context.setAuthentication(authentication); - SecurityContextHolder.setContext(context); + this.securityContextHolderStrategy.setContext(context); } private Authentication getAuthentication(Object user) { @@ -132,22 +141,22 @@ public final class SecurityContextChannelInterceptor implements ExecutorChannelI private void cleanup() { Stack contextStack = originalContext.get(); if (contextStack == null || contextStack.isEmpty()) { - SecurityContextHolder.clearContext(); + this.securityContextHolderStrategy.clearContext(); originalContext.remove(); return; } SecurityContext context = contextStack.pop(); try { - if (SecurityContextChannelInterceptor.EMPTY_CONTEXT.equals(context)) { - SecurityContextHolder.clearContext(); + if (SecurityContextChannelInterceptor.this.empty.equals(context)) { + this.securityContextHolderStrategy.clearContext(); originalContext.remove(); } else { - SecurityContextHolder.setContext(context); + this.securityContextHolderStrategy.setContext(context); } } catch (Throwable ex) { - SecurityContextHolder.clearContext(); + this.securityContextHolderStrategy.clearContext(); } } diff --git a/messaging/src/test/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptorTests.java b/messaging/src/test/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptorTests.java index 774c13530a..628af60558 100644 --- a/messaging/src/test/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptorTests.java +++ b/messaging/src/test/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptorTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,9 +34,13 @@ import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; +import org.springframework.security.core.context.SecurityContextImpl; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; @ExtendWith(MockitoExtension.class) public class SecurityContextChannelInterceptorTests { @@ -94,6 +98,17 @@ public class SecurityContextChannelInterceptorTests { assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.authentication); } + @Test + public void preSendWhenCustomSecurityContextHolderStrategyThenUserSet() { + SecurityContextHolderStrategy strategy = spy(SecurityContextHolder.getContextHolderStrategy()); + strategy.setContext(new SecurityContextImpl(this.authentication)); + this.interceptor.setSecurityContextHolderStrategy(strategy); + this.messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, this.authentication); + this.interceptor.preSend(this.messageBuilder.build(), this.channel); + verify(strategy).getContext(); + assertThat(strategy.getContext().getAuthentication()).isSameAs(this.authentication); + } + @Test public void setAnonymousAuthenticationNull() { assertThatIllegalArgumentException().isThrownBy(() -> this.interceptor.setAnonymousAuthentication(null)); @@ -143,6 +158,16 @@ public class SecurityContextChannelInterceptorTests { assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); } + @Test + public void afterSendCompletionWhenCustomSecurityContextHolderStrategyThenNullAuthentication() { + SecurityContextHolderStrategy strategy = spy(SecurityContextHolder.getContextHolderStrategy()); + strategy.setContext(new SecurityContextImpl(this.authentication)); + this.interceptor.setSecurityContextHolderStrategy(strategy); + this.interceptor.afterSendCompletion(this.messageBuilder.build(), this.channel, true, null); + verify(strategy).clearContext(); + assertThat(strategy.getContext().getAuthentication()).isNull(); + } + @Test public void beforeHandleUserSet() { this.messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, this.authentication); @@ -150,6 +175,17 @@ public class SecurityContextChannelInterceptorTests { assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.authentication); } + @Test + public void beforeHandleWhenCustomSecurityContextHolderStrategyThenUserSet() { + SecurityContextHolderStrategy strategy = spy(SecurityContextHolder.getContextHolderStrategy()); + strategy.setContext(new SecurityContextImpl(this.authentication)); + this.interceptor.setSecurityContextHolderStrategy(strategy); + this.messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, this.authentication); + this.interceptor.beforeHandle(this.messageBuilder.build(), this.channel, this.handler); + verify(strategy).getContext(); + assertThat(strategy.getContext().getAuthentication()).isSameAs(this.authentication); + } + // SEC-2845 @Test public void beforeHandleUserNotAuthentication() { @@ -178,6 +214,15 @@ public class SecurityContextChannelInterceptorTests { assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); } + @Test + public void afterMessageHandledWhenCustomSecurityContextHolderStrategyThenUses() { + SecurityContextHolderStrategy strategy = spy(SecurityContextHolder.getContextHolderStrategy()); + strategy.setContext(new SecurityContextImpl(this.authentication)); + this.interceptor.setSecurityContextHolderStrategy(strategy); + this.interceptor.afterMessageHandled(this.messageBuilder.build(), this.channel, this.handler, null); + verify(strategy).clearContext(); + } + // SEC-2829 @Test public void restoresOriginalContext() {