Use SecurityContextHolderStrategy for Messaging

Issue gh-11060
This commit is contained in:
Josh Cummings 2022-06-21 16:36:09 -06:00
parent 275586be5f
commit 1e498df39b
No known key found for this signature in database
GPG Key ID: A306A51F43B8E5A5
4 changed files with 108 additions and 24 deletions

View File

@ -32,6 +32,7 @@ import org.springframework.security.authorization.AuthorizationEventPublisher;
import org.springframework.security.authorization.AuthorizationManager; import org.springframework.security.authorization.AuthorizationManager;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.util.Assert; import org.springframework.util.Assert;
/** /**
@ -42,14 +43,8 @@ import org.springframework.util.Assert;
*/ */
public final class AuthorizationChannelInterceptor implements ChannelInterceptor { public final class AuthorizationChannelInterceptor implements ChannelInterceptor {
static final Supplier<Authentication> AUTHENTICATION_SUPPLIER = () -> { private Supplier<Authentication> authentication = getAuthentication(
Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); SecurityContextHolder.getContextHolderStrategy());
if (authentication == null) {
throw new AuthenticationCredentialsNotFoundException(
"An Authentication object was not found in the SecurityContext");
}
return authentication;
};
private final Log logger = LogFactory.getLog(this.getClass()); private final Log logger = LogFactory.getLog(this.getClass());
@ -71,8 +66,8 @@ public final class AuthorizationChannelInterceptor implements ChannelInterceptor
@Override @Override
public Message<?> preSend(Message<?> message, MessageChannel channel) { public Message<?> preSend(Message<?> message, MessageChannel channel) {
this.logger.debug(LogMessage.of(() -> "Authorizing message send")); this.logger.debug(LogMessage.of(() -> "Authorizing message send"));
AuthorizationDecision decision = this.preSendAuthorizationManager.check(AUTHENTICATION_SUPPLIER, message); AuthorizationDecision decision = this.preSendAuthorizationManager.check(this.authentication, message);
this.eventPublisher.publishAuthorizationEvent(AUTHENTICATION_SUPPLIER, message, decision); this.eventPublisher.publishAuthorizationEvent(this.authentication, message, decision);
if (decision == null || !decision.isGranted()) { // default deny if (decision == null || !decision.isGranted()) { // default deny
this.logger.debug(LogMessage.of(() -> "Failed to authorize message with authorization manager " this.logger.debug(LogMessage.of(() -> "Failed to authorize message with authorization manager "
+ this.preSendAuthorizationManager + " and decision " + decision)); + this.preSendAuthorizationManager + " and decision " + decision));
@ -82,6 +77,14 @@ public final class AuthorizationChannelInterceptor implements ChannelInterceptor
return message; return message;
} }
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
this.authentication = getAuthentication(securityContextHolderStrategy);
}
/** /**
* Use this {@link AuthorizationEventPublisher} to publish the * Use this {@link AuthorizationEventPublisher} to publish the
* {@link AuthorizationManager} result. * {@link AuthorizationManager} result.
@ -92,6 +95,17 @@ public final class AuthorizationChannelInterceptor implements ChannelInterceptor
this.eventPublisher = eventPublisher; this.eventPublisher = eventPublisher;
} }
private Supplier<Authentication> getAuthentication(SecurityContextHolderStrategy strategy) {
return () -> {
Authentication authentication = strategy.getContext().getAuthentication();
if (authentication == null) {
throw new AuthenticationCredentialsNotFoundException(
"An Authentication object was not found in the SecurityContext");
}
return authentication;
};
}
private static class NoopAuthorizationEventPublisher implements AuthorizationEventPublisher { private static class NoopAuthorizationEventPublisher implements AuthorizationEventPublisher {
@Override @Override

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2021 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -29,7 +29,9 @@ import org.springframework.messaging.handler.invocation.HandlerMethodArgumentRes
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.annotation.AuthenticationPrincipal; import org.springframework.security.core.annotation.AuthenticationPrincipal;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.stereotype.Controller; import org.springframework.stereotype.Controller;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils; import org.springframework.util.ClassUtils;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
@ -85,6 +87,9 @@ import org.springframework.util.StringUtils;
*/ */
public final class AuthenticationPrincipalArgumentResolver implements HandlerMethodArgumentResolver { public final class AuthenticationPrincipalArgumentResolver implements HandlerMethodArgumentResolver {
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
private ExpressionParser parser = new SpelExpressionParser(); private ExpressionParser parser = new SpelExpressionParser();
@Override @Override
@ -94,7 +99,7 @@ public final class AuthenticationPrincipalArgumentResolver implements HandlerMet
@Override @Override
public Object resolveArgument(MethodParameter parameter, Message<?> message) { public Object resolveArgument(MethodParameter parameter, Message<?> message) {
Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication();
if (authentication == null) { if (authentication == null) {
return null; return null;
} }
@ -117,6 +122,17 @@ public final class AuthenticationPrincipalArgumentResolver implements HandlerMet
return principal; return principal;
} }
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*
* @since 5.8
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
/** /**
* Obtains the specified {@link Annotation} on the specified {@link MethodParameter}. * Obtains the specified {@link Annotation} on the specified {@link MethodParameter}.
* @param annotationClass the class of the {@link Annotation} to find on the * @param annotationClass the class of the {@link Annotation} to find on the

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2016 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -29,6 +29,7 @@ import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.util.Assert; import org.springframework.util.Assert;
/** /**
@ -42,10 +43,13 @@ import org.springframework.util.Assert;
*/ */
public final class SecurityContextChannelInterceptor implements ExecutorChannelInterceptor, ChannelInterceptor { public final class SecurityContextChannelInterceptor implements ExecutorChannelInterceptor, ChannelInterceptor {
private static final SecurityContext EMPTY_CONTEXT = SecurityContextHolder.createEmptyContext();
private static final ThreadLocal<Stack<SecurityContext>> originalContext = new ThreadLocal<>(); private static final ThreadLocal<Stack<SecurityContext>> originalContext = new ThreadLocal<>();
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
private SecurityContext empty = this.securityContextHolderStrategy.createEmptyContext();
private final String authenticationHeaderName; private final String authenticationHeaderName;
private Authentication anonymous = new AnonymousAuthenticationToken("key", "anonymous", private Authentication anonymous = new AnonymousAuthenticationToken("key", "anonymous",
@ -107,8 +111,13 @@ public final class SecurityContextChannelInterceptor implements ExecutorChannelI
cleanup(); cleanup();
} }
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy strategy) {
this.securityContextHolderStrategy = strategy;
this.empty = this.securityContextHolderStrategy.createEmptyContext();
}
private void setup(Message<?> message) { private void setup(Message<?> message) {
SecurityContext currentContext = SecurityContextHolder.getContext(); SecurityContext currentContext = this.securityContextHolderStrategy.getContext();
Stack<SecurityContext> contextStack = originalContext.get(); Stack<SecurityContext> contextStack = originalContext.get();
if (contextStack == null) { if (contextStack == null) {
contextStack = new Stack<>(); contextStack = new Stack<>();
@ -117,9 +126,9 @@ public final class SecurityContextChannelInterceptor implements ExecutorChannelI
contextStack.push(currentContext); contextStack.push(currentContext);
Object user = message.getHeaders().get(this.authenticationHeaderName); Object user = message.getHeaders().get(this.authenticationHeaderName);
Authentication authentication = getAuthentication(user); Authentication authentication = getAuthentication(user);
SecurityContext context = SecurityContextHolder.createEmptyContext(); SecurityContext context = this.securityContextHolderStrategy.createEmptyContext();
context.setAuthentication(authentication); context.setAuthentication(authentication);
SecurityContextHolder.setContext(context); this.securityContextHolderStrategy.setContext(context);
} }
private Authentication getAuthentication(Object user) { private Authentication getAuthentication(Object user) {
@ -132,22 +141,22 @@ public final class SecurityContextChannelInterceptor implements ExecutorChannelI
private void cleanup() { private void cleanup() {
Stack<SecurityContext> contextStack = originalContext.get(); Stack<SecurityContext> contextStack = originalContext.get();
if (contextStack == null || contextStack.isEmpty()) { if (contextStack == null || contextStack.isEmpty()) {
SecurityContextHolder.clearContext(); this.securityContextHolderStrategy.clearContext();
originalContext.remove(); originalContext.remove();
return; return;
} }
SecurityContext context = contextStack.pop(); SecurityContext context = contextStack.pop();
try { try {
if (SecurityContextChannelInterceptor.EMPTY_CONTEXT.equals(context)) { if (SecurityContextChannelInterceptor.this.empty.equals(context)) {
SecurityContextHolder.clearContext(); this.securityContextHolderStrategy.clearContext();
originalContext.remove(); originalContext.remove();
} }
else { else {
SecurityContextHolder.setContext(context); this.securityContextHolderStrategy.setContext(context);
} }
} }
catch (Throwable ex) { catch (Throwable ex) {
SecurityContextHolder.clearContext(); this.securityContextHolderStrategy.clearContext();
} }
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2016 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -34,9 +34,13 @@ import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.core.context.SecurityContextImpl;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
@ExtendWith(MockitoExtension.class) @ExtendWith(MockitoExtension.class)
public class SecurityContextChannelInterceptorTests { public class SecurityContextChannelInterceptorTests {
@ -94,6 +98,17 @@ public class SecurityContextChannelInterceptorTests {
assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.authentication); assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.authentication);
} }
@Test
public void preSendWhenCustomSecurityContextHolderStrategyThenUserSet() {
SecurityContextHolderStrategy strategy = spy(SecurityContextHolder.getContextHolderStrategy());
strategy.setContext(new SecurityContextImpl(this.authentication));
this.interceptor.setSecurityContextHolderStrategy(strategy);
this.messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, this.authentication);
this.interceptor.preSend(this.messageBuilder.build(), this.channel);
verify(strategy).getContext();
assertThat(strategy.getContext().getAuthentication()).isSameAs(this.authentication);
}
@Test @Test
public void setAnonymousAuthenticationNull() { public void setAnonymousAuthenticationNull() {
assertThatIllegalArgumentException().isThrownBy(() -> this.interceptor.setAnonymousAuthentication(null)); assertThatIllegalArgumentException().isThrownBy(() -> this.interceptor.setAnonymousAuthentication(null));
@ -143,6 +158,16 @@ public class SecurityContextChannelInterceptorTests {
assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
} }
@Test
public void afterSendCompletionWhenCustomSecurityContextHolderStrategyThenNullAuthentication() {
SecurityContextHolderStrategy strategy = spy(SecurityContextHolder.getContextHolderStrategy());
strategy.setContext(new SecurityContextImpl(this.authentication));
this.interceptor.setSecurityContextHolderStrategy(strategy);
this.interceptor.afterSendCompletion(this.messageBuilder.build(), this.channel, true, null);
verify(strategy).clearContext();
assertThat(strategy.getContext().getAuthentication()).isNull();
}
@Test @Test
public void beforeHandleUserSet() { public void beforeHandleUserSet() {
this.messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, this.authentication); this.messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, this.authentication);
@ -150,6 +175,17 @@ public class SecurityContextChannelInterceptorTests {
assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.authentication); assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.authentication);
} }
@Test
public void beforeHandleWhenCustomSecurityContextHolderStrategyThenUserSet() {
SecurityContextHolderStrategy strategy = spy(SecurityContextHolder.getContextHolderStrategy());
strategy.setContext(new SecurityContextImpl(this.authentication));
this.interceptor.setSecurityContextHolderStrategy(strategy);
this.messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, this.authentication);
this.interceptor.beforeHandle(this.messageBuilder.build(), this.channel, this.handler);
verify(strategy).getContext();
assertThat(strategy.getContext().getAuthentication()).isSameAs(this.authentication);
}
// SEC-2845 // SEC-2845
@Test @Test
public void beforeHandleUserNotAuthentication() { public void beforeHandleUserNotAuthentication() {
@ -178,6 +214,15 @@ public class SecurityContextChannelInterceptorTests {
assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
} }
@Test
public void afterMessageHandledWhenCustomSecurityContextHolderStrategyThenUses() {
SecurityContextHolderStrategy strategy = spy(SecurityContextHolder.getContextHolderStrategy());
strategy.setContext(new SecurityContextImpl(this.authentication));
this.interceptor.setSecurityContextHolderStrategy(strategy);
this.interceptor.afterMessageHandled(this.messageBuilder.build(), this.channel, this.handler, null);
verify(strategy).clearContext();
}
// SEC-2829 // SEC-2829
@Test @Test
public void restoresOriginalContext() { public void restoresOriginalContext() {