Use SecurityContextHolderStrategy for Messaging

Issue gh-11060
This commit is contained in:
Josh Cummings 2022-06-21 16:36:09 -06:00
parent 6e821382f1
commit b05fed8b9d
No known key found for this signature in database
GPG Key ID: A306A51F43B8E5A5
4 changed files with 108 additions and 24 deletions

View File

@ -32,6 +32,7 @@ import org.springframework.security.authorization.AuthorizationEventPublisher;
import org.springframework.security.authorization.AuthorizationManager;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.util.Assert;
/**
@ -42,14 +43,8 @@ import org.springframework.util.Assert;
*/
public final class AuthorizationChannelInterceptor implements ChannelInterceptor {
static final Supplier<Authentication> AUTHENTICATION_SUPPLIER = () -> {
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
if (authentication == null) {
throw new AuthenticationCredentialsNotFoundException(
"An Authentication object was not found in the SecurityContext");
}
return authentication;
};
private Supplier<Authentication> authentication = getAuthentication(
SecurityContextHolder.getContextHolderStrategy());
private final Log logger = LogFactory.getLog(this.getClass());
@ -71,8 +66,8 @@ public final class AuthorizationChannelInterceptor implements ChannelInterceptor
@Override
public Message<?> preSend(Message<?> message, MessageChannel channel) {
this.logger.debug(LogMessage.of(() -> "Authorizing message send"));
AuthorizationDecision decision = this.preSendAuthorizationManager.check(AUTHENTICATION_SUPPLIER, message);
this.eventPublisher.publishAuthorizationEvent(AUTHENTICATION_SUPPLIER, message, decision);
AuthorizationDecision decision = this.preSendAuthorizationManager.check(this.authentication, message);
this.eventPublisher.publishAuthorizationEvent(this.authentication, message, decision);
if (decision == null || !decision.isGranted()) { // default deny
this.logger.debug(LogMessage.of(() -> "Failed to authorize message with authorization manager "
+ this.preSendAuthorizationManager + " and decision " + decision));
@ -82,6 +77,14 @@ public final class AuthorizationChannelInterceptor implements ChannelInterceptor
return message;
}
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
this.authentication = getAuthentication(securityContextHolderStrategy);
}
/**
* Use this {@link AuthorizationEventPublisher} to publish the
* {@link AuthorizationManager} result.
@ -92,6 +95,17 @@ public final class AuthorizationChannelInterceptor implements ChannelInterceptor
this.eventPublisher = eventPublisher;
}
private Supplier<Authentication> getAuthentication(SecurityContextHolderStrategy strategy) {
return () -> {
Authentication authentication = strategy.getContext().getAuthentication();
if (authentication == null) {
throw new AuthenticationCredentialsNotFoundException(
"An Authentication object was not found in the SecurityContext");
}
return authentication;
};
}
private static class NoopAuthorizationEventPublisher implements AuthorizationEventPublisher {
@Override

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -29,7 +29,9 @@ import org.springframework.messaging.handler.invocation.HandlerMethodArgumentRes
import org.springframework.security.core.Authentication;
import org.springframework.security.core.annotation.AuthenticationPrincipal;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.stereotype.Controller;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.StringUtils;
@ -85,6 +87,9 @@ import org.springframework.util.StringUtils;
*/
public final class AuthenticationPrincipalArgumentResolver implements HandlerMethodArgumentResolver {
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
private ExpressionParser parser = new SpelExpressionParser();
@Override
@ -94,7 +99,7 @@ public final class AuthenticationPrincipalArgumentResolver implements HandlerMet
@Override
public Object resolveArgument(MethodParameter parameter, Message<?> message) {
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication();
if (authentication == null) {
return null;
}
@ -117,6 +122,17 @@ public final class AuthenticationPrincipalArgumentResolver implements HandlerMet
return principal;
}
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*
* @since 5.8
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
/**
* Obtains the specified {@link Annotation} on the specified {@link MethodParameter}.
* @param annotationClass the class of the {@link Annotation} to find on the

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2016 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -29,6 +29,7 @@ import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.util.Assert;
/**
@ -43,10 +44,13 @@ import org.springframework.util.Assert;
public final class SecurityContextChannelInterceptor extends ChannelInterceptorAdapter
implements ExecutorChannelInterceptor {
private static final SecurityContext EMPTY_CONTEXT = SecurityContextHolder.createEmptyContext();
private static final ThreadLocal<Stack<SecurityContext>> originalContext = new ThreadLocal<>();
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
private SecurityContext empty = this.securityContextHolderStrategy.createEmptyContext();
private final String authenticationHeaderName;
private Authentication anonymous = new AnonymousAuthenticationToken("key", "anonymous",
@ -108,8 +112,13 @@ public final class SecurityContextChannelInterceptor extends ChannelInterceptorA
cleanup();
}
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy strategy) {
this.securityContextHolderStrategy = strategy;
this.empty = this.securityContextHolderStrategy.createEmptyContext();
}
private void setup(Message<?> message) {
SecurityContext currentContext = SecurityContextHolder.getContext();
SecurityContext currentContext = this.securityContextHolderStrategy.getContext();
Stack<SecurityContext> contextStack = originalContext.get();
if (contextStack == null) {
contextStack = new Stack<>();
@ -118,9 +127,9 @@ public final class SecurityContextChannelInterceptor extends ChannelInterceptorA
contextStack.push(currentContext);
Object user = message.getHeaders().get(this.authenticationHeaderName);
Authentication authentication = getAuthentication(user);
SecurityContext context = SecurityContextHolder.createEmptyContext();
SecurityContext context = this.securityContextHolderStrategy.createEmptyContext();
context.setAuthentication(authentication);
SecurityContextHolder.setContext(context);
this.securityContextHolderStrategy.setContext(context);
}
private Authentication getAuthentication(Object user) {
@ -133,22 +142,22 @@ public final class SecurityContextChannelInterceptor extends ChannelInterceptorA
private void cleanup() {
Stack<SecurityContext> contextStack = originalContext.get();
if (contextStack == null || contextStack.isEmpty()) {
SecurityContextHolder.clearContext();
this.securityContextHolderStrategy.clearContext();
originalContext.remove();
return;
}
SecurityContext context = contextStack.pop();
try {
if (SecurityContextChannelInterceptor.EMPTY_CONTEXT.equals(context)) {
SecurityContextHolder.clearContext();
if (SecurityContextChannelInterceptor.this.empty.equals(context)) {
this.securityContextHolderStrategy.clearContext();
originalContext.remove();
}
else {
SecurityContextHolder.setContext(context);
this.securityContextHolderStrategy.setContext(context);
}
}
catch (Throwable ex) {
SecurityContextHolder.clearContext();
this.securityContextHolderStrategy.clearContext();
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2016 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -34,9 +34,13 @@ import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.core.context.SecurityContextImpl;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
@ExtendWith(MockitoExtension.class)
public class SecurityContextChannelInterceptorTests {
@ -94,6 +98,17 @@ public class SecurityContextChannelInterceptorTests {
assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.authentication);
}
@Test
public void preSendWhenCustomSecurityContextHolderStrategyThenUserSet() {
SecurityContextHolderStrategy strategy = spy(SecurityContextHolder.getContextHolderStrategy());
strategy.setContext(new SecurityContextImpl(this.authentication));
this.interceptor.setSecurityContextHolderStrategy(strategy);
this.messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, this.authentication);
this.interceptor.preSend(this.messageBuilder.build(), this.channel);
verify(strategy).getContext();
assertThat(strategy.getContext().getAuthentication()).isSameAs(this.authentication);
}
@Test
public void setAnonymousAuthenticationNull() {
assertThatIllegalArgumentException().isThrownBy(() -> this.interceptor.setAnonymousAuthentication(null));
@ -143,6 +158,16 @@ public class SecurityContextChannelInterceptorTests {
assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
}
@Test
public void afterSendCompletionWhenCustomSecurityContextHolderStrategyThenNullAuthentication() {
SecurityContextHolderStrategy strategy = spy(SecurityContextHolder.getContextHolderStrategy());
strategy.setContext(new SecurityContextImpl(this.authentication));
this.interceptor.setSecurityContextHolderStrategy(strategy);
this.interceptor.afterSendCompletion(this.messageBuilder.build(), this.channel, true, null);
verify(strategy).clearContext();
assertThat(strategy.getContext().getAuthentication()).isNull();
}
@Test
public void beforeHandleUserSet() {
this.messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, this.authentication);
@ -150,6 +175,17 @@ public class SecurityContextChannelInterceptorTests {
assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.authentication);
}
@Test
public void beforeHandleWhenCustomSecurityContextHolderStrategyThenUserSet() {
SecurityContextHolderStrategy strategy = spy(SecurityContextHolder.getContextHolderStrategy());
strategy.setContext(new SecurityContextImpl(this.authentication));
this.interceptor.setSecurityContextHolderStrategy(strategy);
this.messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, this.authentication);
this.interceptor.beforeHandle(this.messageBuilder.build(), this.channel, this.handler);
verify(strategy).getContext();
assertThat(strategy.getContext().getAuthentication()).isSameAs(this.authentication);
}
// SEC-2845
@Test
public void beforeHandleUserNotAuthentication() {
@ -178,6 +214,15 @@ public class SecurityContextChannelInterceptorTests {
assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
}
@Test
public void afterMessageHandledWhenCustomSecurityContextHolderStrategyThenUses() {
SecurityContextHolderStrategy strategy = spy(SecurityContextHolder.getContextHolderStrategy());
strategy.setContext(new SecurityContextImpl(this.authentication));
this.interceptor.setSecurityContextHolderStrategy(strategy);
this.interceptor.afterMessageHandled(this.messageBuilder.build(), this.channel, this.handler, null);
verify(strategy).clearContext();
}
// SEC-2829
@Test
public void restoresOriginalContext() {